From c5aee561167d955349a27d2a55ad9b9467311309 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 16 Feb 2024 11:26:52 -0800 Subject: [PATCH] feat: Add native shuffle and columnar shuffle (#30) * feat: Add native shuffle and columnar shuffle * For review --- .../scala/org/apache/comet/CometConf.scala | 108 +- .../shuffle/ArrowReaderIterator.scala | 104 ++ .../CometBlockStoreShuffleReader.scala | 177 +++ .../shuffle/CometShuffleDependency.scala | 59 + .../shuffle/CometShuffledRowRDD.scala | 159 +++ .../shuffle/IpcInputStreamIterator.scala | 127 ++ .../execution/shuffle/ShuffleUtils.scala | 42 + core/Cargo.toml | 4 + core/benches/row_columnar.rs | 109 ++ core/src/execution/datafusion/mod.rs | 1 + core/src/execution/datafusion/planner.rs | 18 + .../execution/datafusion/shuffle_writer.rs | 1257 +++++++++++++++++ core/src/execution/jni_api.rs | 84 +- core/src/execution/mod.rs | 2 + core/src/execution/proto/operator.proto | 9 +- core/src/execution/shuffle/list.rs | 344 +++++ core/src/execution/shuffle/mod.rs | 19 + core/src/execution/shuffle/row.rs | 933 ++++++++++++ core/src/execution/sort.rs | 222 +++ dev/ensure-jars-have-correct-contents.sh | 7 + .../comet/CometShuffleChecksumSupport.java | 50 + .../comet/CometShuffleMemoryAllocator.java | 200 +++ .../shuffle/comet/TooLargePageException.java | 26 + .../sort/CometShuffleExternalSorter.java | 627 ++++++++ .../CometBypassMergeSortShuffleWriter.java | 370 +++++ .../shuffle/CometDiskBlockWriter.java | 451 ++++++ .../shuffle/CometUnsafeShuffleWriter.java | 573 ++++++++ .../shuffle/ExposedByteArrayOutputStream.java | 33 + .../execution/shuffle/ShuffleThreadPool.java | 70 + .../comet/execution/shuffle/SpillInfo.java | 37 + .../comet/execution/shuffle/SpillWriter.java | 233 +++ .../comet/CometSparkSessionExtensions.scala | 109 +- .../main/scala/org/apache/comet/Native.scala | 45 + .../shims/ShimCometShuffleExchangeExec.scala | 39 + .../main/scala/org/apache/spark/Plugins.scala | 8 +- .../spark/shuffle/sort/RowPartition.scala | 42 + .../shuffle/CometShuffleExchangeExec.scala | 381 +++++ .../shuffle/CometShuffleManager.scala | 280 ++++ .../CometSparkSessionExtensionsSuite.scala | 38 + .../comet/exec/CometAggregateSuite.scala | 410 +++--- .../apache/comet/exec/CometExecSuite.scala | 249 +++- .../apache/comet/exec/CometShuffleSuite.scala | 843 +++++++++++ .../org/apache/spark/CometPluginsSuite.scala | 2 +- .../spark/sql/CometTPCDSQuerySuite.scala | 4 + .../spark/sql/CometTPCHQuerySuite.scala | 4 + .../org/apache/spark/sql/CometTestBase.scala | 37 +- .../sql/benchmark/CometExecBenchmark.scala | 10 +- .../sql/benchmark/CometShuffleBenchmark.scala | 609 ++++++++ .../sql/comet/CometPlanStabilitySuite.scala | 5 + 49 files changed, 9313 insertions(+), 257 deletions(-) create mode 100644 common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala create mode 100644 common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala create mode 100644 common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala create mode 100644 common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala create mode 100644 common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala create mode 100644 common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ShuffleUtils.scala create mode 100644 core/benches/row_columnar.rs create mode 100644 core/src/execution/datafusion/shuffle_writer.rs create mode 100644 core/src/execution/shuffle/list.rs create mode 100644 core/src/execution/shuffle/mod.rs create mode 100644 core/src/execution/shuffle/row.rs create mode 100644 core/src/execution/sort.rs mode change 100644 => 100755 dev/ensure-jars-have-correct-contents.sh create mode 100644 spark/src/main/java/org/apache/spark/shuffle/comet/CometShuffleChecksumSupport.java create mode 100644 spark/src/main/java/org/apache/spark/shuffle/comet/CometShuffleMemoryAllocator.java create mode 100644 spark/src/main/java/org/apache/spark/shuffle/comet/TooLargePageException.java create mode 100644 spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java create mode 100644 spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java create mode 100644 spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java create mode 100644 spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java create mode 100644 spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/ExposedByteArrayOutputStream.java create mode 100644 spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/ShuffleThreadPool.java create mode 100644 spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillInfo.java create mode 100644 spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java create mode 100644 spark/src/main/scala/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala create mode 100644 spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala create mode 100644 spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala create mode 100644 spark/src/test/scala/org/apache/spark/sql/benchmark/CometShuffleBenchmark.scala diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 7f83d9296..a3dc1f38a 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -77,7 +77,7 @@ object CometConf { "The amount of additional memory to be allocated per executor process for Comet, in MiB. " + "This config is optional. If this is not specified, it will be set to " + "`spark.comet.memory.overhead.factor` * `spark.executor.memory`. " + - "This is memory that accounts for things like Comet native execution, etc.") + "This is memory that accounts for things like Comet native execution, Comet shuffle, etc.") .bytesConf(ByteUnit.MiB) .createOptional @@ -119,6 +119,112 @@ object CometConf { .booleanConf .createWithDefault(false) + val COMET_EXEC_SHUFFLE_ENABLED: ConfigEntry[Boolean] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.enabled") + .doc( + "Whether to enable Comet native shuffle. By default, this config is false. " + + "Note that this requires setting 'spark.shuffle.manager' to " + + "'org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager'. " + + "'spark.shuffle.manager' must be set before starting the Spark application and " + + "cannot be changed during the application.") + .booleanConf + .createWithDefault(false) + + val COMET_COLUMNAR_SHUFFLE_ENABLED: ConfigEntry[Boolean] = conf( + "spark.comet.columnar.shuffle.enabled") + .doc( + "Force Comet to only use columnar shuffle for CometScan and Spark regular operators. " + + "If this is enabled, Comet native shuffle will not be enabled but only Arrow shuffle. " + + "By default, this config is false.") + .booleanConf + .createWithDefault(false) + + val COMET_EXEC_SHUFFLE_CODEC: ConfigEntry[String] = conf( + s"$COMET_EXEC_CONFIG_PREFIX.shuffle.codec") + .doc( + "The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported.") + .stringConf + .createWithDefault("zstd") + + val COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED: ConfigEntry[Boolean] = conf( + "spark.comet.columnar.shuffle.async.enabled") + .doc( + "Whether to enable asynchronous shuffle for Arrow-based shuffle. By default, this config " + + "is false.") + .booleanConf + .createWithDefault(false) + + val COMET_EXEC_SHUFFLE_ASYNC_THREAD_NUM: ConfigEntry[Int] = + conf("spark.comet.columnar.shuffle.async.thread.num") + .doc("Number of threads used for Comet async columnar shuffle per shuffle task. " + + "By default, this config is 3. Note that more threads means more memory requirement to " + + "buffer shuffle data before flushing to disk. Also, more threads may not always " + + "improve performance, and should be set based on the number of cores available.") + .intConf + .createWithDefault(3) + + val COMET_EXEC_SHUFFLE_ASYNC_MAX_THREAD_NUM: ConfigEntry[Int] = { + conf("spark.comet.columnar.shuffle.async.max.thread.num") + .doc("Maximum number of threads on an executor used for Comet async columnar shuffle. " + + "By default, this config is 100. This is the upper bound of total number of shuffle " + + "threads per executor. In other words, if the number of cores * the number of shuffle " + + "threads per task `spark.comet.columnar.shuffle.async.thread.num` is larger than " + + "this config. Comet will use this config as the number of shuffle threads per " + + "executor instead.") + .intConf + .createWithDefault(100) + } + + val COMET_EXEC_SHUFFLE_SPILL_THRESHOLD: ConfigEntry[Int] = + conf("spark.comet.columnar.shuffle.spill.threshold") + .doc( + "Number of rows to be spilled used for Comet columnar shuffle. " + + "For every configured number of rows, a new spill file will be created. " + + "Higher value means more memory requirement to buffer shuffle data before " + + "flushing to disk. As Comet uses columnar shuffle which is columnar format, " + + "higher value usually helps to improve shuffle data compression ratio. This is " + + "internal config for testing purpose or advanced tuning. By default, " + + "this config is Int.Max.") + .internal() + .intConf + .createWithDefault(Int.MaxValue) + + val COMET_COLUMNAR_SHUFFLE_MEMORY_SIZE: OptionalConfigEntry[Long] = + conf("spark.comet.columnar.shuffle.memorySize") + .doc( + "The optional maximum size of the memory used for Comet columnar shuffle, in MiB. " + + "Note that this config is only used when `spark.comet.columnar.shuffle.enabled` is " + + "true. Once allocated memory size reaches this config, the current batch will be " + + "flushed to disk immediately. If this is not configured, Comet will use " + + "`spark.comet.shuffle.memory.factor` * `spark.comet.memoryOverhead` as " + + "shuffle memory size. If final calculated value is larger than Comet memory " + + "overhead, Comet will use Comet memory overhead as shuffle memory size.") + .bytesConf(ByteUnit.MiB) + .createOptional + + val COMET_COLUMNAR_SHUFFLE_MEMORY_FACTOR: ConfigEntry[Double] = + conf("spark.comet.columnar.shuffle.memory.factor") + .doc( + "Fraction of Comet memory to be allocated per executor process for Comet shuffle. " + + "Comet memory size is specified by `spark.comet.memoryOverhead` or " + + "calculated by `spark.comet.memory.overhead.factor` * `spark.executor.memory`. " + + "By default, this config is 1.0.") + .doubleConf + .checkValue( + factor => factor > 0, + "Ensure that Comet shuffle memory overhead factor is a double greater than 0") + .createWithDefault(1.0) + + val COMET_SHUFFLE_PREFER_DICTIONARY_RATIO: ConfigEntry[Double] = conf( + "spark.comet.shuffle.preferDictionary.ratio") + .doc("The ratio of total values to distinct values in a string column to decide whether to " + + "prefer dictionary encoding when shuffling the column. If the ratio is higher than " + + "this config, dictionary encoding will be used on shuffling string column. This config " + + "is effective if it is higher than 1.0. By default, this config is 10.0. Note that this " + + "config is only used when 'spark.comet.columnar.shuffle.enabled' is true.") + .doubleConf + .createWithDefault(10.0) + val COMET_DEBUG_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.debug.enabled") .doc( diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala new file mode 100644 index 000000000..c17c5bce9 --- /dev/null +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle + +import java.nio.channels.ReadableByteChannel + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.CometConf +import org.apache.comet.vector.{NativeUtil, StreamReader} + +class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] { + + private val nativeUtil = new NativeUtil + + private val maxBatchSize = CometConf.COMET_BATCH_SIZE.get(SQLConf.get) + + private val reader = StreamReader(channel) + private var currentIdx = -1 + private var batch = nextBatch() + private var previousBatch: ColumnarBatch = null + private var currentBatch: ColumnarBatch = null + + override def hasNext: Boolean = { + if (batch.isDefined) { + return true + } + + batch = nextBatch() + if (batch.isEmpty) { + return false + } + true + } + + override def next(): ColumnarBatch = { + if (!hasNext) { + throw new NoSuchElementException + } + + val nextBatch = batch.get + val batchRows = nextBatch.numRows() + val numRows = Math.min(batchRows - currentIdx, maxBatchSize) + + // Release the previous sliced batch. + // If it is not released, when closing the reader, arrow library will complain about + // memory leak. + if (currentBatch != null) { + // Close plain arrays in the previous sliced batch. + // The dictionary arrays will be closed when closing the entire batch. + currentBatch.close() + } + + currentBatch = nativeUtil.takeRows(nextBatch, currentIdx, numRows) + currentIdx += numRows + + if (currentIdx == batchRows) { + // We cannot close the batch here, because if there is dictionary array in the batch, + // the dictionary array will be closed immediately, and the returned sliced batch will + // be invalid. + previousBatch = batch.get + + batch = None + currentIdx = -1 + } + + currentBatch + } + + private def nextBatch(): Option[ColumnarBatch] = { + if (previousBatch != null) { + previousBatch.close() + previousBatch = null + } + currentIdx = 0 + reader.nextBatch() + } + + def close(): Unit = + synchronized { + if (currentBatch != null) { + currentBatch.close() + } + reader.close() + } +} diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala new file mode 100644 index 000000000..b461b53f5 --- /dev/null +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle + +import java.io.InputStream + +import org.apache.spark.InterruptibleIterator +import org.apache.spark.MapOutputTracker +import org.apache.spark.SparkEnv +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config +import org.apache.spark.io.CompressionCodec +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.shuffle.BaseShuffleHandle +import org.apache.spark.shuffle.ShuffleReader +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.storage.BlockId +import org.apache.spark.storage.BlockManager +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.ShuffleBlockFetcherIterator +import org.apache.spark.util.CompletionIterator + +/** + * Shuffle reader that reads data from the block manager. It reads Arrow-serialized data (IPC + * format) and returns an iterator of ColumnarBatch. + */ +class CometBlockStoreShuffleReader[K, C]( + handle: BaseShuffleHandle[K, _, C], + blocksByAddress: Iterator[(BlockManagerId, scala.collection.Seq[(BlockId, Long, Int)])], + context: TaskContext, + readMetrics: ShuffleReadMetricsReporter, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, + shouldBatchFetch: Boolean = false) + extends ShuffleReader[K, C] + with Logging { + + private val dep = handle.dependency.asInstanceOf[CometShuffleDependency[_, _, _]] + + private def fetchIterator: Iterator[(BlockId, InputStream)] = { + new ShuffleBlockFetcherIterator( + context, + blockManager.blockStoreClient, + blockManager, + mapOutputTracker, + // To tackle Scala issue between Seq and scala.collection.Seq + blocksByAddress.map(pair => (pair._1, pair._2.toSeq)), + (_, inputStream) => { + if (dep.shuffleType == CometColumnarShuffle) { + // Only columnar shuffle supports encryption + serializerManager.wrapForEncryption(inputStream) + } else { + inputStream + } + }, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, + SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED), + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM), + readMetrics, + fetchContinuousBlocksInBatch).toCompletionIterator + } + + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + val recordIter = fetchIterator + .flatMap { case (_, inputStream) => + var currentReadIterator: ArrowReaderIterator = null + + // Closes last read iterator after the task is finished. + // We need to close read iterator during iterating input streams, + // instead of one callback per read iterator. Otherwise if there are too many + // read iterators, it may blow up the call stack and cause OOM. + context.addTaskCompletionListener[Unit] { _ => + if (currentReadIterator != null) { + currentReadIterator.close() + } + } + + IpcInputStreamIterator(inputStream, decompressingNeeded = true, context) + .flatMap { channel => + if (currentReadIterator != null) { + // Closes previous read iterator. + currentReadIterator.close() + } + currentReadIterator = new ArrowReaderIterator(channel) + currentReadIterator.map((0, _)) // use 0 as key since it's not used + } + } + + // Update the context task metrics for each record read. + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map { record => + readMetrics.incRecordsRead(record._2.numRows()) + record + }, + context.taskMetrics().mergeShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) + + val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { + throw new UnsupportedOperationException("aggregate not allowed") + } else { + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] + } + + // Sort the output if there is a sort ordering defined. + val resultIter = dep.keyOrdering match { + case Some(_: Ordering[K]) => + throw new UnsupportedOperationException("order not allowed") + case None => + aggregatedIter + } + + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) + } + } + + private def fetchContinuousBlocksInBatch: Boolean = { + val conf = SparkEnv.get.conf + val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects + val compressed = conf.get(config.SHUFFLE_COMPRESS) + val codecConcatenation = if (compressed) { + CompressionCodec.supportsConcatenationOfSerializedStreams( + CompressionCodec.createCodec(conf)) + } else { + true + } + val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL) + + // SPARK-34790: Fetching continuous blocks in batch is incompatible with io encryption. + val ioEncryption = conf.get(config.IO_ENCRYPTION_ENABLED) + + val doBatchFetch = shouldBatchFetch && serializerRelocatable && + (!compressed || codecConcatenation) && !useOldFetchProtocol && !ioEncryption + if (shouldBatchFetch && !doBatchFetch) { + logDebug( + "The feature tag of continuous shuffle block fetching is set to true, but " + + "we can not enable the feature because other conditions are not satisfied. " + + s"Shuffle compress: $compressed, serializer relocatable: $serializerRelocatable, " + + s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " + + s"$useOldFetchProtocol, io encryption: $ioEncryption.") + } + doBatchFetch + } +} diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala new file mode 100644 index 000000000..7b1d1f127 --- /dev/null +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle + +import scala.reflect.ClassTag + +import org.apache.spark.{Aggregator, Partitioner, ShuffleDependency, SparkEnv} +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleWriteProcessor +import org.apache.spark.sql.types.StructType + +/** + * A [[ShuffleDependency]] that allows us to identify the shuffle dependency as a Comet shuffle. + */ +class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( + @transient private val _rdd: RDD[_ <: Product2[K, V]], + override val partitioner: Partitioner, + override val serializer: Serializer = SparkEnv.get.serializer, + override val keyOrdering: Option[Ordering[K]] = None, + override val aggregator: Option[Aggregator[K, V, C]] = None, + override val mapSideCombine: Boolean = false, + override val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor, + val shuffleType: ShuffleType = CometNativeShuffle, + val schema: Option[StructType] = None) + extends ShuffleDependency[K, V, C]( + _rdd, + partitioner, + serializer, + keyOrdering, + aggregator, + mapSideCombine, + shuffleWriterProcessor) {} + +/** Indicates shuffle type */ +sealed trait ShuffleType + +/** Indicates that the shuffle is performed by Comet native library */ +case object CometNativeShuffle extends ShuffleType + +/** Indicates that the shuffle is performed by Comet JVM class */ +case object CometColumnarShuffle extends ShuffleType diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala new file mode 100644 index 000000000..af78ed290 --- /dev/null +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle + +import org.apache.spark.{Dependency, MapOutputTrackerMaster, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.sql.execution.{CoalescedMapperPartitionSpec, CoalescedPartitioner, CoalescedPartitionSpec, PartialMapperPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Different from [[org.apache.spark.sql.execution.ShuffledRowRDD]], this RDD is specialized for + * reading shuffled data through [[CometBlockStoreShuffleReader]]. The shuffled data is read in an + * iterator of [[Product2[Int, ColumnarBatch]]] instead of [[Product2[Int, InternalRow]]]. + */ +class CometShuffledBatchRDD( + var dependency: ShuffleDependency[Int, _, _], + metrics: Map[String, SQLMetric], + partitionSpecs: Array[ShufflePartitionSpec]) + extends RDD[ColumnarBatch](dependency.rdd.context, Nil) { + + def this(dependency: ShuffleDependency[Int, _, _], metrics: Map[String, SQLMetric]) = { + this( + dependency, + metrics, + Array.tabulate(dependency.partitioner.numPartitions)(i => CoalescedPartitionSpec(i, i + 1))) + } + + dependency.rdd.context.setLocalProperty( + SortShuffleManager.FETCH_SHUFFLE_BLOCKS_IN_BATCH_ENABLED_KEY, + SQLConf.get.fetchShuffleBlocksInBatch.toString) + + override def getDependencies: Seq[Dependency[_]] = List(dependency) + + override val partitioner: Option[Partitioner] = + if (partitionSpecs.forall(_.isInstanceOf[CoalescedPartitionSpec])) { + val indices = partitionSpecs.map(_.asInstanceOf[CoalescedPartitionSpec].startReducerIndex) + // TODO this check is based on assumptions of callers' behavior but is sufficient for now. + if (indices.toSet.size == partitionSpecs.length) { + Some(new CoalescedPartitioner(dependency.partitioner, indices)) + } else { + None + } + } else { + None + } + + override def getPartitions: Array[Partition] = { + Array.tabulate[Partition](partitionSpecs.length) { i => + ShuffledRowRDDPartition(i, partitionSpecs(i)) + } + } + + override def getPreferredLocations(partition: Partition): Seq[String] = { + val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + partition.asInstanceOf[ShuffledRowRDDPartition].spec match { + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => + // TODO order by partition size. + startReducerIndex.until(endReducerIndex).flatMap { reducerIndex => + tracker.getPreferredLocationsForShuffle(dependency, reducerIndex) + } + + case PartialReducerPartitionSpec(_, startMapIndex, endMapIndex, _) => + tracker.getMapLocation(dependency, startMapIndex, endMapIndex) + + case PartialMapperPartitionSpec(mapIndex, _, _) => + tracker.getMapLocation(dependency, mapIndex, mapIndex + 1) + + case CoalescedMapperPartitionSpec(startMapIndex, endMapIndex, _) => + tracker.getMapLocation(dependency, startMapIndex, endMapIndex) + } + } + + override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() + // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator, + // as well as the `tempMetrics` for basic shuffle metrics. + val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics) + val reader = split.asInstanceOf[ShuffledRowRDDPartition].spec match { + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => + SparkEnv.get.shuffleManager + .getReader( + dependency.shuffleHandle, + startReducerIndex, + endReducerIndex, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + + case PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex, _) => + SparkEnv.get.shuffleManager + .getReader( + dependency.shuffleHandle, + startMapIndex, + endMapIndex, + reducerIndex, + reducerIndex + 1, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + + case PartialMapperPartitionSpec(mapIndex, startReducerIndex, endReducerIndex) => + SparkEnv.get.shuffleManager + .getReader( + dependency.shuffleHandle, + mapIndex, + mapIndex + 1, + startReducerIndex, + endReducerIndex, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + + case CoalescedMapperPartitionSpec(startMapIndex, endMapIndex, numReducers) => + SparkEnv.get.shuffleManager + .getReader( + dependency.shuffleHandle, + startMapIndex, + endMapIndex, + 0, + numReducers, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + } + + // TODO: Reads IPC by native code + reader.read().asInstanceOf[Iterator[Product2[Int, ColumnarBatch]]].map(_._2) + } + + override def clearDependencies(): Unit = { + super.clearDependencies() + dependency = null + } +} + +/** + * The [[Partition]] used by [[CometShuffledRowRDD]]. + */ +final case class ShuffledRowRDDPartition(index: Int, spec: ShufflePartitionSpec) extends Partition diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala new file mode 100644 index 000000000..281c48108 --- /dev/null +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle + +import java.io.EOFException +import java.io.InputStream +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.nio.channels.Channels +import java.nio.channels.ReadableByteChannel + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.LimitedInputStream + +case class IpcInputStreamIterator( + var in: InputStream, + decompressingNeeded: Boolean, + taskContext: TaskContext) + extends Iterator[ReadableByteChannel] + with Logging { + + private[execution] val channel: ReadableByteChannel = if (in != null) { + Channels.newChannel(in) + } else { + null + } + + private val ipcLengthsBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) + + // NOTE: + // since all ipcs are sharing the same input stream and channel, the second + // hasNext() must be called after the first ipc has been completely processed. + + private[execution] var consumed = true + private var finished = false + private var currentIpcLength = 0L + private var currentLimitedInputStream: LimitedInputStream = _ + + taskContext.addTaskCompletionListener[Unit](_ => { + closeInputStream() + }) + + override def hasNext: Boolean = { + if (in == null || finished) { + return false + } + + // If we've read the length of the next IPC, we don't need to read it again. + if (!consumed) { + return true + } + + if (currentLimitedInputStream != null) { + currentLimitedInputStream.skip(Int.MaxValue) + currentLimitedInputStream = null + } + + // Reads the length of IPC bytes + ipcLengthsBuf.clear() + while (ipcLengthsBuf.hasRemaining && channel.read(ipcLengthsBuf) >= 0) {} + + // If we reach the end of the stream, we are done, or if we read partial length + // then the stream is corrupted. + if (ipcLengthsBuf.hasRemaining) { + if (ipcLengthsBuf.position() == 0) { + finished = true + closeInputStream() + return false + } + throw new EOFException("Data corrupt: unexpected EOF while reading compressed ipc lengths") + } + + ipcLengthsBuf.flip() + currentIpcLength = ipcLengthsBuf.getLong + + // Skips empty IPC + if (currentIpcLength == 0) { + return hasNext + } + consumed = false + return true + } + + override def next(): ReadableByteChannel = { + if (!hasNext) { + throw new NoSuchElementException + } + assert(!consumed) + consumed = true + + val is = new LimitedInputStream(Channels.newInputStream(channel), currentIpcLength, false) + currentLimitedInputStream = is + + if (decompressingNeeded) { + val zs = ShuffleUtils.compressionCodecForShuffling.compressedInputStream(is) + Channels.newChannel(zs) + } else { + Channels.newChannel(is) + } + } + + private def closeInputStream(): Unit = + synchronized { + if (in != null) { + in.close() + in = null + } + } +} diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ShuffleUtils.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ShuffleUtils.scala new file mode 100644 index 000000000..eea134ab5 --- /dev/null +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ShuffleUtils.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.IO_COMPRESSION_CODEC +import org.apache.spark.io.CompressionCodec +import org.apache.spark.sql.internal.SQLConf + +import org.apache.comet.CometConf + +private[spark] object ShuffleUtils extends Logging { + lazy val compressionCodecForShuffling: CompressionCodec = { + val sparkConf = SparkEnv.get.conf + val codecName = CometConf.COMET_EXEC_SHUFFLE_CODEC.get(SQLConf.get) + + // only zstd compression is supported at the moment + if (codecName != "zstd") { + logWarning( + s"Overriding config ${IO_COMPRESSION_CODEC}=${codecName} in shuffling, force using zstd") + } + CompressionCodec.createCodec(sparkConf, "zstd") + } +} diff --git a/core/Cargo.toml b/core/Cargo.toml index adc3732e3..d27b83366 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -113,3 +113,7 @@ harness = false name = "hash" harness = false +[[bench]] +name = "row_columnar" +harness = false + diff --git a/core/benches/row_columnar.rs b/core/benches/row_columnar.rs new file mode 100644 index 000000000..46f8233c6 --- /dev/null +++ b/core/benches/row_columnar.rs @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType as ArrowDataType; +use comet::execution::shuffle::row::{ + process_sorted_row_partition, SparkUnsafeObject, SparkUnsafeRow, +}; +use criterion::{criterion_group, criterion_main, Criterion}; +use tempfile::Builder; + +const NUM_ROWS: usize = 10000; +const NUM_COLS: usize = 100; +const ROW_SIZE: usize = SparkUnsafeRow::get_row_bitset_width(NUM_COLS) + NUM_COLS * 8; + +fn benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("row_array_conversion"); + + group.bench_function("row_to_array", |b| { + let spark_rows = (0..NUM_ROWS) + .map(|_| { + let mut spark_row = SparkUnsafeRow::new_with_num_fields(NUM_COLS); + let mut row = Row::new(); + + for i in SparkUnsafeRow::get_row_bitset_width(NUM_COLS)..ROW_SIZE { + row.data[i] = i as u8; + } + + row.to_spark_row(&mut spark_row); + + for i in 0..NUM_COLS { + spark_row.set_not_null_at(i); + } + + spark_row + }) + .collect::>(); + + let mut row_addresses = spark_rows + .iter() + .map(|row| row.get_row_addr()) + .collect::>(); + let mut row_sizes = spark_rows + .iter() + .map(|row| row.get_row_size()) + .collect::>(); + + let row_address_ptr = row_addresses.as_mut_ptr(); + let row_size_ptr = row_sizes.as_mut_ptr(); + let schema = vec![ArrowDataType::Int64; NUM_COLS]; + + b.iter(|| { + let tempfile = Builder::new().tempfile().unwrap(); + + process_sorted_row_partition( + NUM_ROWS, + row_address_ptr, + row_size_ptr, + &schema, + tempfile.path().to_str().unwrap().to_string(), + 1.0, + false, + 0, + None, + ) + .unwrap(); + }); + }); +} + +struct Row { + data: Box<[u8; ROW_SIZE]>, +} + +impl Row { + pub fn new() -> Self { + Row { + data: Box::new([0u8; ROW_SIZE]), + } + } + + pub fn to_spark_row(&self, spark_row: &mut SparkUnsafeRow) { + spark_row.point_to_slice(self.data.as_ref()); + } +} + +fn config() -> Criterion { + Criterion::default() +} + +criterion_group! { + name = benches; + config = config(); + targets = benchmark +} +criterion_main!(benches); diff --git a/core/src/execution/datafusion/mod.rs b/core/src/execution/datafusion/mod.rs index b74ee2638..f9fafeb29 100644 --- a/core/src/execution/datafusion/mod.rs +++ b/core/src/execution/datafusion/mod.rs @@ -20,4 +20,5 @@ mod expressions; mod operators; pub mod planner; +pub(crate) mod shuffle_writer; mod spark_hash; diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 1a6aa5ac3..0cd4ace00 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -68,6 +68,7 @@ use crate::{ NormalizeNaNAndZero, }, operators::expand::CometExpandExec, + shuffle_writer::ShuffleWriterExec, }, operators::{CopyExec, ExecutionError, InputBatch, ScanExec}, serde::to_arrow_datatype, @@ -753,6 +754,23 @@ impl PhysicalPlanner { let scan = ScanExec::new(input_batch, fields); Ok((vec![scan.clone()], Arc::new(scan))) } + OpStruct::ShuffleWriter(writer) => { + assert!(children.len() == 1); + let (scans, child) = self.create_plan(&children[0], input_batches)?; + + let partitioning = self + .create_partitioning(writer.partitioning.as_ref().unwrap(), child.schema())?; + + Ok(( + scans, + Arc::new(ShuffleWriterExec::try_new( + child, + partitioning, + writer.output_data_file.clone(), + writer.output_index_file.clone(), + )?), + )) + } OpStruct::Expand(expand) => { assert!(children.len() == 1); let (scans, child) = self.create_plan(&children[0], input_batches)?; diff --git a/core/src/execution/datafusion/shuffle_writer.rs b/core/src/execution/datafusion/shuffle_writer.rs new file mode 100644 index 000000000..7b7911f16 --- /dev/null +++ b/core/src/execution/datafusion/shuffle_writer.rs @@ -0,0 +1,1257 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the External shuffle repartition plan. + +use std::{ + any::Any, + fmt, + fmt::{Debug, Formatter}, + fs::{File, OpenOptions}, + io::{BufReader, BufWriter, Cursor, Read, Seek, SeekFrom, Write}, + path::Path, + sync::Arc, + task::{Context, Poll}, +}; + +use arrow::{datatypes::*, ipc::writer::StreamWriter}; +use async_trait::async_trait; +use bytes::Buf; +use crc32fast::Hasher; +use datafusion::{ + arrow::{ + array::*, + datatypes::{DataType, SchemaRef, TimeUnit}, + error::{ArrowError, Result as ArrowResult}, + record_batch::RecordBatch, + }, + error::{DataFusionError, Result}, + execution::{ + context::TaskContext, + disk_manager::RefCountedTempFile, + memory_pool::{MemoryConsumer, MemoryReservation}, + runtime_env::RuntimeEnv, + }, + physical_plan::{ + expressions::PhysicalSortExpr, + metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, + SendableRecordBatchStream, Statistics, + }, +}; +use futures::{lock::Mutex, Stream, StreamExt, TryFutureExt, TryStreamExt}; +use itertools::Itertools; +use simd_adler32::Adler32; +use tokio::task; + +use crate::{ + common::bit::ceil, + execution::datafusion::spark_hash::{create_hashes, pmod}, +}; + +/// The shuffle writer operator maps each input partition to M output partitions based on a +/// partitioning scheme. No guarantees are made about the order of the resulting partitions. +#[derive(Debug)] +pub struct ShuffleWriterExec { + /// Input execution plan + input: Arc, + /// Partitioning scheme to use + partitioning: Partitioning, + /// Output data file path + output_data_file: String, + /// Output index file path + output_index_file: String, + /// Metrics + metrics: ExecutionPlanMetricsSet, +} + +impl DisplayAs for ShuffleWriterExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "ShuffleWriterExec: partitioning={:?}", self.partitioning) + } + } + } +} + +#[async_trait] +impl ExecutionPlan for ShuffleWriterExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + /// Get the schema for this execution plan + fn schema(&self) -> SchemaRef { + self.input.schema() + } + + fn output_partitioning(&self) -> Partitioning { + self.partitioning.clone() + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + match children.len() { + 1 => Ok(Arc::new(ShuffleWriterExec::try_new( + children[0].clone(), + self.partitioning.clone(), + self.output_data_file.clone(), + self.output_index_file.clone(), + )?)), + _ => panic!("ShuffleWriterExec wrong number of children"), + } + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let input = self.input.execute(partition, context.clone())?; + let metrics = ShuffleRepartitionerMetrics::new(&self.metrics, 0); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + futures::stream::once( + external_shuffle( + input, + partition, + self.output_data_file.clone(), + self.output_index_file.clone(), + self.partitioning.clone(), + metrics, + context, + ) + .map_err(|e| ArrowError::ExternalError(Box::new(e))), + ) + .try_flatten(), + ))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + self.input.statistics() + } +} + +impl ShuffleWriterExec { + /// Create a new ShuffleWriterExec + pub fn try_new( + input: Arc, + partitioning: Partitioning, + output_data_file: String, + output_index_file: String, + ) -> Result { + Ok(ShuffleWriterExec { + input, + partitioning, + metrics: ExecutionPlanMetricsSet::new(), + output_data_file, + output_index_file, + }) + } +} + +struct PartitionBuffer { + /// The schema of batches to be partitioned. + schema: SchemaRef, + /// The "frozen" Arrow IPC bytes of active data. They are frozen when `flush` is called. + frozen: Vec, + /// Array builders for appending rows into buffering batches. + active: Vec>, + /// The estimation of memory size of active builders in bytes when they are filled. + active_slots_mem_size: usize, + /// Number of rows in active builders. + num_active_rows: usize, + /// The maximum number of rows in a batch. Once `num_active_rows` reaches `batch_size`, + /// the active array builders will be frozen and appended to frozen buffer `frozen`. + batch_size: usize, +} + +impl PartitionBuffer { + fn new(schema: SchemaRef, batch_size: usize) -> Self { + Self { + schema, + frozen: vec![], + active: vec![], + active_slots_mem_size: 0, + num_active_rows: 0, + batch_size, + } + } + + /// Initializes active builders if necessary. + fn init_active_if_necessary(&mut self) -> Result { + let mut mem_diff = 0; + + if self.active.is_empty() { + self.active = new_array_builders(&self.schema, self.batch_size); + if self.active_slots_mem_size == 0 { + self.active_slots_mem_size = self + .active + .iter() + .zip(self.schema.fields()) + .map(|(_ab, field)| slot_size(self.batch_size, field.data_type())) + .sum::(); + } + mem_diff += self.active_slots_mem_size as isize; + } + Ok(mem_diff) + } + + /// Appends all rows of given batch into active array builders. + fn append_batch(&mut self, batch: &RecordBatch) -> Result { + let columns = batch.columns(); + let indices = (0..batch.num_rows()).collect::>(); + self.append_rows(columns, &indices) + } + + /// Appends rows of specified indices from columns into active array builders. + fn append_rows(&mut self, columns: &[ArrayRef], indices: &[usize]) -> Result { + let mut mem_diff = 0; + let mut start = 0; + + // lazy init because some partition may be empty + mem_diff += self.init_active_if_necessary()?; + + while start < indices.len() { + let end = (start + self.batch_size).min(indices.len()); + self.active + .iter_mut() + .zip(columns) + .for_each(|(builder, column)| { + append_columns(builder, column, &indices[start..end], column.data_type()); + }); + self.num_active_rows += end - start; + if self.num_active_rows >= self.batch_size { + mem_diff += self.flush()?; + mem_diff += self.init_active_if_necessary()?; + } + start = end; + } + Ok(mem_diff) + } + + /// flush active data into frozen bytes + fn flush(&mut self) -> Result { + if self.num_active_rows == 0 { + return Ok(0); + } + let mut mem_diff = 0isize; + + // active -> staging + let active = std::mem::take(&mut self.active); + self.num_active_rows = 0; + mem_diff -= self.active_slots_mem_size as isize; + + let frozen_batch = make_batch(self.schema.clone(), active)?; + + let frozen_capacity_old = self.frozen.capacity(); + let mut cursor = Cursor::new(&mut self.frozen); + cursor.seek(SeekFrom::End(0))?; + write_ipc_compressed(&frozen_batch, &mut cursor)?; + + mem_diff += (self.frozen.capacity() - frozen_capacity_old) as isize; + Ok(mem_diff) + } +} + +fn slot_size(len: usize, data_type: &DataType) -> usize { + match data_type { + DataType::Boolean => ceil(len, 8), + DataType::Int8 => len, + DataType::Int16 => len * 2, + DataType::Int32 => len * 4, + DataType::Int64 => len * 8, + DataType::UInt8 => len, + DataType::UInt16 => len * 2, + DataType::UInt32 => len * 4, + DataType::UInt64 => len * 8, + DataType::Float32 => len * 4, + DataType::Float64 => len * 8, + DataType::Date32 => len * 4, + DataType::Date64 => len * 8, + DataType::Time32(TimeUnit::Second) => len * 4, + DataType::Time32(TimeUnit::Millisecond) => len * 4, + DataType::Time64(TimeUnit::Microsecond) => len * 8, + DataType::Time64(TimeUnit::Nanosecond) => len * 8, + // TODO: this is not accurate, but should be good enough for now + DataType::Utf8 => len * 100 + len * 4, + DataType::LargeUtf8 => len * 100 + len * 8, + DataType::Decimal128(_, _) => len * 16, + DataType::Dictionary(key_type, value_type) => { + // TODO: this is not accurate, but should be good enough for now + slot_size(len, key_type.as_ref()) + slot_size(len / 10, value_type.as_ref()) + } + DataType::FixedSizeBinary(s) => len * (*s as usize), + DataType::Timestamp(_, _) => len * 8, + dt => unimplemented!( + "{}", + format!("data type {dt} not supported in shuffle write") + ), + } +} + +fn append_columns( + to: &mut Box, + from: &Arc, + indices: &[usize], + data_type: &DataType, +) { + /// Append values from `from` to `to` using `indices`. + macro_rules! append { + ($arrowty:ident) => {{ + type B = paste::paste! {[< $arrowty Builder >]}; + type A = paste::paste! {[< $arrowty Array >]}; + let t = to.as_any_mut().downcast_mut::().unwrap(); + let f = from.as_any().downcast_ref::().unwrap(); + for &i in indices { + if f.is_valid(i) { + t.append_value(f.value(i)); + } else { + t.append_null(); + } + } + }}; + } + + /// Some array builder (e.g. `FixedSizeBinary`) its `append_value` method returning + /// a `Result`. + macro_rules! append_unwrap { + ($arrowty:ident) => {{ + type B = paste::paste! {[< $arrowty Builder >]}; + type A = paste::paste! {[< $arrowty Array >]}; + let t = to.as_any_mut().downcast_mut::().unwrap(); + let f = from.as_any().downcast_ref::().unwrap(); + for &i in indices { + if f.is_valid(i) { + t.append_value(f.value(i)).unwrap(); + } else { + t.append_null(); + } + } + }}; + } + + /// Appends values from a dictionary array to a dictionary builder. + macro_rules! append_dict { + ($kt:ty, $builder:ty, $dict_array:ty) => {{ + let t = to.as_any_mut().downcast_mut::<$builder>().unwrap(); + let f = from + .as_any() + .downcast_ref::>() + .unwrap() + .downcast_dict::<$dict_array>() + .unwrap(); + for &i in indices { + if f.is_valid(i) { + t.append_value(f.value(i)); + } else { + t.append_null(); + } + } + }}; + } + + macro_rules! append_dict_helper { + ($kt:ident, $ty:ty, $dict_array:ty) => {{ + match $kt.as_ref() { + DataType::Int8 => append_dict!(Int8Type, PrimitiveDictionaryBuilder, $dict_array), + DataType::Int16 => append_dict!(Int16Type, PrimitiveDictionaryBuilder, $dict_array), + DataType::Int32 => append_dict!(Int32Type, PrimitiveDictionaryBuilder, $dict_array), + DataType::Int64 => append_dict!(Int64Type, PrimitiveDictionaryBuilder, $dict_array), + DataType::UInt8 => append_dict!(UInt8Type, PrimitiveDictionaryBuilder, $dict_array), + DataType::UInt16 => { + append_dict!(UInt16Type, PrimitiveDictionaryBuilder, $dict_array) + } + DataType::UInt32 => { + append_dict!(UInt32Type, PrimitiveDictionaryBuilder, $dict_array) + } + DataType::UInt64 => { + append_dict!(UInt64Type, PrimitiveDictionaryBuilder, $dict_array) + } + _ => unreachable!("Unknown key type for dictionary"), + } + }}; + } + + macro_rules! primitive_append_dict_helper { + ($kt:ident, $vt:ident) => { + match $vt.as_ref() { + DataType::Int8 => { + append_dict_helper!($kt, Int8Type, Int8Array) + } + DataType::Int16 => { + append_dict_helper!($kt, Int16Type, Int16Array) + } + DataType::Int32 => { + append_dict_helper!($kt, Int32Type, Int32Array) + } + DataType::Int64 => { + append_dict_helper!($kt, Int64Type, Int64Array) + } + DataType::UInt8 => { + append_dict_helper!($kt, UInt8Type, UInt8Array) + } + DataType::UInt16 => { + append_dict_helper!($kt, UInt16Type, UInt16Array) + } + DataType::UInt32 => { + append_dict_helper!($kt, UInt32Type, UInt32Array) + } + DataType::UInt64 => { + append_dict_helper!($kt, UInt64Type, UInt64Array) + } + DataType::Float32 => { + append_dict_helper!($kt, Float32Type, Float32Array) + } + DataType::Float64 => { + append_dict_helper!($kt, Float64Type, Float64Array) + } + DataType::Decimal128(_, _) => { + append_dict_helper!($kt, Decimal128Type, Decimal128Array) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + append_dict_helper!($kt, TimestampMicrosecondType, TimestampMicrosecondArray) + } + DataType::Date32 => { + append_dict_helper!($kt, Date32Type, Date32Array) + } + DataType::Date64 => { + append_dict_helper!($kt, Date64Type, Date64Array) + } + t => unimplemented!("{:?} is not supported for appending dictionary builder", t), + } + }; + } + + macro_rules! append_string_dict { + ($kt:ident) => {{ + match $kt.as_ref() { + DataType::Int8 => { + append_dict!(Int8Type, StringDictionaryBuilder, StringArray) + } + DataType::Int16 => { + append_dict!(Int16Type, StringDictionaryBuilder, StringArray) + } + DataType::Int32 => { + append_dict!(Int32Type, StringDictionaryBuilder, StringArray) + } + DataType::Int64 => { + append_dict!(Int64Type, StringDictionaryBuilder, StringArray) + } + DataType::UInt8 => { + append_dict!(UInt8Type, StringDictionaryBuilder, StringArray) + } + DataType::UInt16 => { + append_dict!(UInt16Type, StringDictionaryBuilder, StringArray) + } + DataType::UInt32 => { + append_dict!(UInt32Type, StringDictionaryBuilder, StringArray) + } + DataType::UInt64 => { + append_dict!(UInt64Type, StringDictionaryBuilder, StringArray) + } + _ => unreachable!("Unknown key type for dictionary"), + } + }}; + } + + match data_type { + DataType::Boolean => append!(Boolean), + DataType::Int8 => append!(Int8), + DataType::Int16 => append!(Int16), + DataType::Int32 => append!(Int32), + DataType::Int64 => append!(Int64), + DataType::UInt8 => append!(UInt8), + DataType::UInt16 => append!(UInt16), + DataType::UInt32 => append!(UInt32), + DataType::UInt64 => append!(UInt64), + DataType::Float32 => append!(Float32), + DataType::Float64 => append!(Float64), + DataType::Date32 => append!(Date32), + DataType::Date64 => append!(Date64), + DataType::Time32(TimeUnit::Second) => append!(Time32Second), + DataType::Time32(TimeUnit::Millisecond) => append!(Time32Millisecond), + DataType::Time64(TimeUnit::Microsecond) => append!(Time64Microsecond), + DataType::Time64(TimeUnit::Nanosecond) => append!(Time64Nanosecond), + DataType::Timestamp(TimeUnit::Microsecond, _) => { + append!(TimestampMicrosecond) + } + DataType::Utf8 => append!(String), + DataType::LargeUtf8 => append!(LargeString), + DataType::Decimal128(_, _) => append!(Decimal128), + DataType::Dictionary(key_type, value_type) if value_type.is_primitive() => { + primitive_append_dict_helper!(key_type, value_type) + } + DataType::Dictionary(key_type, value_type) + if matches!(value_type.as_ref(), DataType::Utf8) => + { + append_string_dict!(key_type) + } + DataType::FixedSizeBinary(_) => append_unwrap!(FixedSizeBinary), + t => unimplemented!( + "{}", + format!("data type {} not supported in shuffle write", t) + ), + } +} + +struct SpillInfo { + file: RefCountedTempFile, + offsets: Vec, +} + +struct ShuffleRepartitioner { + output_data_file: String, + output_index_file: String, + schema: SchemaRef, + buffered_partitions: Mutex>, + spills: Mutex>, + /// Sort expressions + /// Partitioning scheme to use + partitioning: Partitioning, + num_output_partitions: usize, + runtime: Arc, + metrics: ShuffleRepartitionerMetrics, + reservation: MemoryReservation, + /// Hashes for each row in the current batch + hashes_buf: Vec, + /// Partition ids for each row in the current batch + partition_ids: Vec, +} + +struct ShuffleRepartitionerMetrics { + /// metrics + baseline: BaselineMetrics, + + /// count of spills during the execution of the operator + spill_count: Count, + + /// total spilled bytes during the execution of the operator + spilled_bytes: Count, +} + +impl ShuffleRepartitionerMetrics { + fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + Self { + baseline: BaselineMetrics::new(metrics, partition), + spill_count: MetricBuilder::new(metrics).spill_count(partition), + spilled_bytes: MetricBuilder::new(metrics).spilled_bytes(partition), + } + } +} + +impl ShuffleRepartitioner { + #[allow(clippy::too_many_arguments)] + pub fn new( + partition_id: usize, + output_data_file: String, + output_index_file: String, + schema: SchemaRef, + partitioning: Partitioning, + metrics: ShuffleRepartitionerMetrics, + runtime: Arc, + batch_size: usize, + ) -> Self { + let num_output_partitions = partitioning.partition_count(); + let reservation = MemoryConsumer::new(format!("ShuffleRepartitioner[{}]", partition_id)) + .with_can_spill(true) + .register(&runtime.memory_pool); + + let mut hashes_buf = Vec::with_capacity(batch_size); + let mut partition_ids = Vec::with_capacity(batch_size); + + // Safety: `hashes_buf` will be filled with valid values before being used. + // `partition_ids` will be filled with valid values before being used. + unsafe { + hashes_buf.set_len(batch_size); + partition_ids.set_len(batch_size); + } + + Self { + output_data_file, + output_index_file, + schema: schema.clone(), + buffered_partitions: Mutex::new( + (0..num_output_partitions) + .map(|_| PartitionBuffer::new(schema.clone(), batch_size)) + .collect::>(), + ), + spills: Mutex::new(vec![]), + partitioning, + num_output_partitions, + runtime, + metrics, + reservation, + hashes_buf, + partition_ids, + } + } + + /// Shuffles rows in input batch into corresponding partition buffer. + /// This function first calculates hashes for rows and then takes rows in same + /// partition as a record batch which is appended into partition buffer. + async fn insert_batch(&mut self, input: RecordBatch) -> Result<()> { + if input.num_rows() == 0 { + // skip empty batch + return Ok(()); + } + let _timer = self.metrics.baseline.elapsed_compute().timer(); + + // NOTE: in shuffle writer exec, the output_rows metrics represents the + // number of rows those are written to output data file. + self.metrics.baseline.record_output(input.num_rows()); + + let num_output_partitions = self.num_output_partitions; + match &self.partitioning { + Partitioning::Hash(exprs, _) => { + let arrays = exprs + .iter() + .map(|expr| expr.evaluate(&input)?.into_array(input.num_rows())) + .collect::>>()?; + + // use identical seed as spark hash partition + let hashes_buf = &mut self.hashes_buf[..arrays[0].len()]; + hashes_buf.fill(42_u32); + + // Hash arrays and compute buckets based on number of partitions + let partition_ids = &mut self.partition_ids[..arrays[0].len()]; + create_hashes(&arrays, hashes_buf)? + .iter() + .enumerate() + .for_each(|(idx, hash)| { + partition_ids[idx] = pmod(*hash, num_output_partitions) as u64 + }); + + // count each partition size + let mut partition_counters = vec![0usize; num_output_partitions]; + partition_ids + .iter() + .for_each(|partition_id| partition_counters[*partition_id as usize] += 1); + + // accumulate partition counters into partition ends + // e.g. partition counter: [1, 3, 2, 1] => [1, 4, 6, 7] + let mut partition_ends = partition_counters; + let mut accum = 0; + partition_ends.iter_mut().for_each(|v| { + *v += accum; + accum = *v; + }); + + // calculate shuffled partition ids + // e.g. partition ids: [3, 1, 1, 1, 2, 2, 0] => [6, 1, 2, 3, 4, 5, 0] which is the + // row indices for rows ordered by their partition id. For example, first partition + // 0 has one row index [6], partition 1 has row indices [1, 2, 3], etc. + let mut shuffled_partition_ids = vec![0usize; input.num_rows()]; + for (index, partition_id) in partition_ids.iter().enumerate().rev() { + partition_ends[*partition_id as usize] -= 1; + let end = partition_ends[*partition_id as usize]; + shuffled_partition_ids[end] = index; + } + + // after calculating, partition ends become partition starts + let mut partition_starts = partition_ends; + partition_starts.push(input.num_rows()); + + let mut mem_diff = 0; + // For each interval of row indices of partition, taking rows from input batch and + // appending into output buffer. + for (partition_id, (&start, &end)) in partition_starts + .iter() + .tuple_windows() + .enumerate() + .filter(|(_, (start, end))| start < end) + { + let mut buffered_partitions = self.buffered_partitions.lock().await; + let output = &mut buffered_partitions[partition_id]; + + // If the range of indices is not big enough, just appending the rows into + // active array builders instead of directly adding them as a record batch. + mem_diff += + output.append_rows(input.columns(), &shuffled_partition_ids[start..end])?; + } + + if mem_diff > 0 { + let mem_increase = mem_diff as usize; + if self.reservation.try_grow(mem_increase).is_err() { + self.spill().await?; + self.reservation.free(); + self.reservation.try_grow(mem_increase)?; + } + } + if mem_diff < 0 { + let mem_used = self.reservation.size(); + let mem_decrease = mem_used.min(-mem_diff as usize); + self.reservation.shrink(mem_decrease); + } + } + Partitioning::UnknownPartitioning(n) if *n == 1 => { + let mut buffered_partitions = self.buffered_partitions.lock().await; + + assert!( + buffered_partitions.len() == 1, + "Expected 1 partition but got {}", + buffered_partitions.len() + ); + + let output = &mut buffered_partitions[0]; + output.append_batch(&input)?; + } + other => { + // this should be unreachable as long as the validation logic + // in the constructor is kept up-to-date + return Err(DataFusionError::NotImplemented(format!( + "Unsupported repartitioning scheme {:?}", + other + ))); + } + } + Ok(()) + } + + /// Writes buffered shuffled record batches into Arrow IPC bytes. + async fn shuffle_write(&mut self) -> Result { + let _timer = self.metrics.baseline.elapsed_compute().timer(); + let num_output_partitions = self.num_output_partitions; + let mut buffered_partitions = self.buffered_partitions.lock().await; + let mut output_batches: Vec> = vec![vec![]; num_output_partitions]; + + for i in 0..num_output_partitions { + buffered_partitions[i].flush()?; + output_batches[i] = std::mem::take(&mut buffered_partitions[i].frozen); + } + + let mut spills = self.spills.lock().await; + let output_spills = spills.drain(..).collect::>(); + + let data_file = self.output_data_file.clone(); + let index_file = self.output_index_file.clone(); + + let mut offsets = vec![0; num_output_partitions + 1]; + let mut output_data = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(data_file) + .map_err(|e| DataFusionError::Execution(format!("shuffle write error: {:?}", e)))?; + + for i in 0..num_output_partitions { + offsets[i] = output_data.stream_position()?; + output_data.write_all(&output_batches[i])?; + output_batches[i].clear(); + + // append partition in each spills + for spill in &output_spills { + let length = spill.offsets[i + 1] - spill.offsets[i]; + if length > 0 { + let mut spill_file = + BufReader::new(File::open(spill.file.path()).map_err(|e| { + DataFusionError::Execution(format!("shuffle write error: {:?}", e)) + })?); + spill_file.seek(SeekFrom::Start(spill.offsets[i]))?; + std::io::copy(&mut spill_file.take(length), &mut output_data).map_err(|e| { + DataFusionError::Execution(format!("shuffle write error: {:?}", e)) + })?; + } + } + } + output_data.flush()?; + + // add one extra offset at last to ease partition length computation + offsets[num_output_partitions] = output_data + .stream_position() + .map_err(|e| DataFusionError::Execution(format!("shuffle write error: {:?}", e)))?; + let mut output_index = + BufWriter::new(File::create(index_file).map_err(|e| { + DataFusionError::Execution(format!("shuffle write error: {:?}", e)) + })?); + for offset in offsets { + output_index + .write_all(&(offset as i64).to_le_bytes()[..]) + .map_err(|e| DataFusionError::Execution(format!("shuffle write error: {:?}", e)))?; + } + output_index.flush()?; + + let used = self.reservation.size(); + self.reservation.shrink(used); + + // shuffle writer always has empty output + Ok(Box::pin(EmptyStream::try_new(self.schema.clone())?)) + } + + fn used(&self) -> usize { + self.reservation.size() + } + + fn spilled_bytes(&self) -> usize { + self.metrics.spilled_bytes.value() + } + + fn spill_count(&self) -> usize { + self.metrics.spill_count.value() + } + + async fn spill(&self) -> Result { + log::debug!( + "ShuffleRepartitioner spilling shuffle data of {} to disk while inserting ({} time(s) so far)", + self.used(), + self.spill_count() + ); + + let mut buffered_partitions = self.buffered_partitions.lock().await; + // we could always get a chance to free some memory as long as we are holding some + if buffered_partitions.len() == 0 { + return Ok(0); + } + + let spillfile = self + .runtime + .disk_manager + .create_tmp_file("shuffle writer spill")?; + let offsets = spill_into( + &mut buffered_partitions, + spillfile.path(), + self.num_output_partitions, + ) + .await?; + + let mut spills = self.spills.lock().await; + let used = self.reservation.size(); + self.metrics.spill_count.add(1); + self.metrics.spilled_bytes.add(used); + spills.push(SpillInfo { + file: spillfile, + offsets, + }); + Ok(used) + } +} + +/// consume the `buffered_partitions` and do spill into a single temp shuffle output file +async fn spill_into( + buffered_partitions: &mut [PartitionBuffer], + path: &Path, + num_output_partitions: usize, +) -> Result> { + let mut output_batches: Vec> = vec![vec![]; num_output_partitions]; + + for i in 0..num_output_partitions { + buffered_partitions[i].flush()?; + output_batches[i] = std::mem::take(&mut buffered_partitions[i].frozen); + } + let path = path.to_owned(); + + task::spawn_blocking(move || { + let mut offsets = vec![0; num_output_partitions + 1]; + let mut spill_data = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(path)?; + + for i in 0..num_output_partitions { + offsets[i] = spill_data.stream_position()?; + spill_data.write_all(&output_batches[i])?; + output_batches[i].clear(); + } + // add one extra offset at last to ease partition length computation + offsets[num_output_partitions] = spill_data.stream_position()?; + Ok(offsets) + }) + .await + .map_err(|e| DataFusionError::Execution(format!("Error occurred while spilling {}", e)))? +} + +impl Debug for ShuffleRepartitioner { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("ShuffleRepartitioner") + .field("memory_used", &self.used()) + .field("spilled_bytes", &self.spilled_bytes()) + .field("spilled_count", &self.spill_count()) + .finish() + } +} + +async fn external_shuffle( + mut input: SendableRecordBatchStream, + partition_id: usize, + output_data_file: String, + output_index_file: String, + partitioning: Partitioning, + metrics: ShuffleRepartitionerMetrics, + context: Arc, +) -> Result { + let schema = input.schema(); + let mut repartitioner = ShuffleRepartitioner::new( + partition_id, + output_data_file, + output_index_file, + schema.clone(), + partitioning, + metrics, + context.runtime_env(), + context.session_config().batch_size(), + ); + + while let Some(batch) = input.next().await { + let batch = batch?; + repartitioner.insert_batch(batch).await?; + } + repartitioner.shuffle_write().await +} + +fn new_array_builders(schema: &SchemaRef, batch_size: usize) -> Vec> { + schema + .fields() + .iter() + .map(|field| { + let dt = field.data_type(); + if matches!(dt, DataType::Dictionary(_, _)) { + make_dict_builder(dt, batch_size) + } else { + make_builder(dt, batch_size) + } + }) + .collect::>() +} + +macro_rules! primitive_dict_builder_inner_helper { + ($kt:ty, $vt:ty, $capacity:ident) => { + Box::new(PrimitiveDictionaryBuilder::<$kt, $vt>::with_capacity( + $capacity, + $capacity / 100, + )) + }; +} + +macro_rules! primitive_dict_builder_helper { + ($kt:ty, $vt:ident, $capacity:ident) => { + match $vt.as_ref() { + DataType::Int8 => { + primitive_dict_builder_inner_helper!($kt, Int8Type, $capacity) + } + DataType::Int16 => { + primitive_dict_builder_inner_helper!($kt, Int16Type, $capacity) + } + DataType::Int32 => { + primitive_dict_builder_inner_helper!($kt, Int32Type, $capacity) + } + DataType::Int64 => { + primitive_dict_builder_inner_helper!($kt, Int64Type, $capacity) + } + DataType::UInt8 => { + primitive_dict_builder_inner_helper!($kt, UInt8Type, $capacity) + } + DataType::UInt16 => { + primitive_dict_builder_inner_helper!($kt, UInt16Type, $capacity) + } + DataType::UInt32 => { + primitive_dict_builder_inner_helper!($kt, UInt32Type, $capacity) + } + DataType::UInt64 => { + primitive_dict_builder_inner_helper!($kt, UInt64Type, $capacity) + } + DataType::Float32 => { + primitive_dict_builder_inner_helper!($kt, Float32Type, $capacity) + } + DataType::Float64 => { + primitive_dict_builder_inner_helper!($kt, Float64Type, $capacity) + } + DataType::Decimal128(p, s) => { + let keys_builder = PrimitiveBuilder::<$kt>::new(); + let values_builder = + Decimal128Builder::new().with_data_type(DataType::Decimal128(*p, *s)); + Box::new( + PrimitiveDictionaryBuilder::<$kt, Decimal128Type>::new_from_empty_builders( + keys_builder, + values_builder, + ), + ) + } + DataType::Timestamp(TimeUnit::Microsecond, timezone) => { + let keys_builder = PrimitiveBuilder::<$kt>::new(); + let values_builder = TimestampMicrosecondBuilder::new() + .with_data_type(DataType::Timestamp(TimeUnit::Microsecond, timezone.clone())); + Box::new( + PrimitiveDictionaryBuilder::<$kt, TimestampMicrosecondType>::new_from_empty_builders( + keys_builder, + values_builder, + ), + ) + } + DataType::Date32 => { + primitive_dict_builder_inner_helper!($kt, Date32Type, $capacity) + } + DataType::Date64 => { + primitive_dict_builder_inner_helper!($kt, Date64Type, $capacity) + } + t => unimplemented!("{:?} is not supported", t), + } + }; +} + +macro_rules! string_dict_builder_inner_helper { + ($kt:ty, $capacity:ident, $builder:ident) => { + Box::new($builder::<$kt>::with_capacity( + $capacity, + $capacity / 100, + $capacity, + )) + }; +} + +/// Returns a dictionary array builder with capacity `capacity` that corresponds to the datatype +/// `DataType` This function is useful to construct arrays from an arbitrary vectors with +/// known/expected schema. +/// TODO: move this to the upstream. +fn make_dict_builder(datatype: &DataType, capacity: usize) -> Box { + match datatype { + DataType::Dictionary(key_type, value_type) if value_type.is_primitive() => { + match key_type.as_ref() { + DataType::Int8 => primitive_dict_builder_helper!(Int8Type, value_type, capacity), + DataType::Int16 => primitive_dict_builder_helper!(Int16Type, value_type, capacity), + DataType::Int32 => primitive_dict_builder_helper!(Int32Type, value_type, capacity), + DataType::Int64 => primitive_dict_builder_helper!(Int64Type, value_type, capacity), + DataType::UInt8 => primitive_dict_builder_helper!(UInt8Type, value_type, capacity), + DataType::UInt16 => { + primitive_dict_builder_helper!(UInt16Type, value_type, capacity) + } + DataType::UInt32 => { + primitive_dict_builder_helper!(UInt32Type, value_type, capacity) + } + DataType::UInt64 => { + primitive_dict_builder_helper!(UInt64Type, value_type, capacity) + } + _ => unreachable!(""), + } + } + DataType::Dictionary(key_type, value_type) + if matches!(value_type.as_ref(), DataType::Utf8) => + { + match key_type.as_ref() { + DataType::Int8 => { + string_dict_builder_inner_helper!(Int16Type, capacity, StringDictionaryBuilder) + } + DataType::Int16 => { + string_dict_builder_inner_helper!(Int16Type, capacity, StringDictionaryBuilder) + } + DataType::Int32 => { + string_dict_builder_inner_helper!(Int32Type, capacity, StringDictionaryBuilder) + } + DataType::Int64 => { + string_dict_builder_inner_helper!(Int64Type, capacity, StringDictionaryBuilder) + } + DataType::UInt8 => { + string_dict_builder_inner_helper!(UInt8Type, capacity, StringDictionaryBuilder) + } + DataType::UInt16 => { + string_dict_builder_inner_helper!(UInt16Type, capacity, StringDictionaryBuilder) + } + DataType::UInt32 => { + string_dict_builder_inner_helper!(UInt32Type, capacity, StringDictionaryBuilder) + } + DataType::UInt64 => { + string_dict_builder_inner_helper!(UInt64Type, capacity, StringDictionaryBuilder) + } + _ => unreachable!(""), + } + } + DataType::Dictionary(key_type, value_type) + if matches!(value_type.as_ref(), DataType::LargeUtf8) => + { + match key_type.as_ref() { + DataType::Int8 => string_dict_builder_inner_helper!( + Int16Type, + capacity, + LargeStringDictionaryBuilder + ), + DataType::Int16 => string_dict_builder_inner_helper!( + Int16Type, + capacity, + LargeStringDictionaryBuilder + ), + DataType::Int32 => string_dict_builder_inner_helper!( + Int32Type, + capacity, + LargeStringDictionaryBuilder + ), + DataType::Int64 => string_dict_builder_inner_helper!( + Int64Type, + capacity, + LargeStringDictionaryBuilder + ), + DataType::UInt8 => string_dict_builder_inner_helper!( + UInt8Type, + capacity, + LargeStringDictionaryBuilder + ), + DataType::UInt16 => { + string_dict_builder_inner_helper!( + UInt16Type, + capacity, + LargeStringDictionaryBuilder + ) + } + DataType::UInt32 => { + string_dict_builder_inner_helper!( + UInt32Type, + capacity, + LargeStringDictionaryBuilder + ) + } + DataType::UInt64 => { + string_dict_builder_inner_helper!( + UInt64Type, + capacity, + LargeStringDictionaryBuilder + ) + } + _ => unreachable!(""), + } + } + t => panic!("Data type {t:?} is not currently supported"), + } +} + +fn make_batch( + schema: SchemaRef, + mut arrays: Vec>, +) -> ArrowResult { + let columns = arrays.iter_mut().map(|array| array.finish()).collect(); + RecordBatch::try_new(schema, columns) +} + +/// Checksum algorithms for writing IPC bytes. +#[derive(Debug, Clone)] +pub(crate) enum ChecksumAlgorithm { + /// CRC32 checksum algorithm. + CRC32(Option), + /// Adler32 checksum algorithm. + Adler32(Option), +} + +pub(crate) fn compute_checksum( + cursor: &mut Cursor<&mut Vec>, + checksum_algorithm: &ChecksumAlgorithm, +) -> Result { + match checksum_algorithm { + ChecksumAlgorithm::CRC32(checksum) => { + let mut hasher = if let Some(initial) = checksum { + Hasher::new_with_initial(*initial) + } else { + Hasher::new() + }; + std::io::Seek::seek(cursor, SeekFrom::Start(0))?; + hasher.update(cursor.chunk()); + + let checksum = hasher.finalize(); + Ok(ChecksumAlgorithm::CRC32(Some(checksum))) + } + ChecksumAlgorithm::Adler32(checksum) => { + let mut hasher = if let Some(initial) = checksum { + // Note that Adler32 initial state is not zero. + // i.e., `Adler32::from_checksum(0)` is not the same as `Adler32::new()`. + Adler32::from_checksum(*initial) + } else { + Adler32::new() + }; + std::io::Seek::seek(cursor, SeekFrom::Start(0))?; + hasher.write(cursor.chunk()); + + let checksum = hasher.finish(); + Ok(ChecksumAlgorithm::Adler32(Some(checksum))) + } + } +} + +/// Writes given record batch as Arrow IPC bytes into given writer. +/// Returns number of bytes written. +pub(crate) fn write_ipc_compressed( + batch: &RecordBatch, + output: &mut W, +) -> Result { + if batch.num_rows() == 0 { + return Ok(0); + } + let start_pos = output.stream_position()?; + + // write ipc_length placeholder + output.write_all(&[0u8; 8])?; + + // write ipc data + // TODO: make compression level configurable + let mut arrow_writer = StreamWriter::try_new(zstd::Encoder::new(output, 1)?, &batch.schema())?; + arrow_writer.write(batch)?; + arrow_writer.finish()?; + + let zwriter = arrow_writer.into_inner()?; + let output = zwriter.finish()?; + let end_pos = output.stream_position()?; + let ipc_length = end_pos - start_pos - 8; + + // fill ipc length + output.seek(SeekFrom::Start(start_pos))?; + output.write_all(&ipc_length.to_le_bytes()[..])?; + + output.seek(SeekFrom::Start(end_pos))?; + Ok((end_pos - start_pos) as usize) +} + +/// A stream that yields no record batches which represent end of output. +pub struct EmptyStream { + /// Schema representing the data + schema: SchemaRef, +} + +impl EmptyStream { + /// Create an iterator for a vector of record batches + pub fn try_new(schema: SchemaRef) -> Result { + Ok(Self { schema }) + } +} + +impl Stream for EmptyStream { + type Item = Result; + + fn poll_next(self: std::pin::Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(None) + } +} + +impl RecordBatchStream for EmptyStream { + /// Get the schema + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} diff --git a/core/src/execution/jni_api.rs b/core/src/execution/jni_api.rs index e940a0853..9981cece3 100644 --- a/core/src/execution/jni_api.rs +++ b/core/src/execution/jni_api.rs @@ -48,14 +48,15 @@ use crate::{ errors::{try_unwrap_or_throw, CometError}, execution::{ datafusion::planner::PhysicalPlanner, metrics::utils::update_comet_metric, - serde::to_arrow_datatype, spark_operator::Operator, + serde::to_arrow_datatype, shuffle::row::process_sorted_row_partition, sort::RdxSort, + spark_operator::Operator, }, jvm_bridge::{jni_new_global_ref, JVMClasses}, }; use futures::stream::StreamExt; use jni::{ objects::{AutoArray, GlobalRef}, - sys::{jbooleanArray, jobjectArray}, + sys::{jboolean, jbooleanArray, jdouble, jintArray, jobjectArray, jstring}, }; use tokio::runtime::Runtime; @@ -505,3 +506,82 @@ fn get_execution_context<'a>(id: i64) -> &'a mut ExecutionContext { .expect("Comet execution context shouldn't be null!") } } + +#[no_mangle] +/// Used by Comet shuffle external sorter to write sorted records to disk. +pub extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative( + env: JNIEnv, + _class: JClass, + row_addresses: jlongArray, + row_sizes: jintArray, + serialized_datatypes: jobjectArray, + file_path: jstring, + prefer_dictionary_ratio: jdouble, + checksum_enabled: jboolean, + checksum_algo: jint, + current_checksum: jlong, +) -> jlongArray { + try_unwrap_or_throw(env, || { + let row_num = env.get_array_length(row_addresses)? as usize; + + let data_types = convert_datatype_arrays(&env, serialized_datatypes)?; + + let row_addresses = env.get_long_array_elements(row_addresses, ReleaseMode::NoCopyBack)?; + let row_sizes = env.get_int_array_elements(row_sizes, ReleaseMode::NoCopyBack)?; + + let row_addresses_ptr = row_addresses.as_ptr(); + let row_sizes_ptr = row_sizes.as_ptr(); + + let output_path: String = env.get_string(JString::from(file_path)).unwrap().into(); + + let checksum_enabled = checksum_enabled == 1; + let current_checksum = if current_checksum == i64::MIN { + // Initial checksum is not available. + None + } else { + Some(current_checksum as u32) + }; + + let (written_bytes, checksum) = process_sorted_row_partition( + row_num, + row_addresses_ptr, + row_sizes_ptr, + &data_types, + output_path, + prefer_dictionary_ratio, + checksum_enabled, + checksum_algo, + current_checksum, + )?; + + let checksum = if let Some(checksum) = checksum { + checksum as i64 + } else { + // Spark checksums (CRC32 or Adler32) are both u32, so we use i64::MIN to indicate + // checksum is not available. + i64::MIN + }; + + let long_array = env.new_long_array(2)?; + env.set_long_array_region(long_array, 0, &[written_bytes, checksum])?; + + Ok(long_array) + }) +} + +#[no_mangle] +/// Used by Comet shuffle external sorter to sort in-memory row partition ids. +pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative( + env: JNIEnv, + _class: JClass, + address: jlong, + size: jlong, +) { + try_unwrap_or_throw(env, || { + // SAFETY: JVM unsafe memory allocation is aligned with long. + let array = unsafe { std::slice::from_raw_parts_mut(address as *mut i64, size as usize) }; + array.rdxsort(); + + Ok(()) + }) +} diff --git a/core/src/execution/mod.rs b/core/src/execution/mod.rs index b0c60cc52..4c57ad8eb 100644 --- a/core/src/execution/mod.rs +++ b/core/src/execution/mod.rs @@ -24,6 +24,8 @@ pub mod kernels; // for benchmarking mod metrics; pub mod operators; pub mod serde; +pub mod shuffle; +pub(crate) mod sort; mod timezone; pub(crate) mod utils; diff --git a/core/src/execution/proto/operator.proto b/core/src/execution/proto/operator.proto index f4f785396..5b07cb30b 100644 --- a/core/src/execution/proto/operator.proto +++ b/core/src/execution/proto/operator.proto @@ -38,7 +38,8 @@ message Operator { Sort sort = 103; HashAggregate hash_agg = 104; Limit limit = 105; - Expand expand = 106; + ShuffleWriter shuffle_writer = 106; + Expand expand = 107; } } @@ -71,6 +72,12 @@ message Limit { int32 offset = 2; } +message ShuffleWriter { + spark.spark_partitioning.Partitioning partitioning = 1; + string output_data_file = 3; + string output_index_file = 4; +} + enum AggregateMode { Partial = 0; Final = 1; diff --git a/core/src/execution/shuffle/list.rs b/core/src/execution/shuffle/list.rs new file mode 100644 index 000000000..21c3b57d5 --- /dev/null +++ b/core/src/execution/shuffle/list.rs @@ -0,0 +1,344 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{ + errors::CometError, + execution::shuffle::row::{append_field, SparkUnsafeObject}, +}; +use arrow_array::builder::{ + ArrayBuilder, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder, Float32Builder, + Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, ListBuilder, + StringBuilder, StructBuilder, TimestampMicrosecondBuilder, +}; +use arrow_schema::{DataType, TimeUnit}; + +pub struct SparkUnsafeArray { + row_addr: i64, + row_size: i32, + num_elements: usize, + element_offset: i64, +} + +impl SparkUnsafeObject for SparkUnsafeArray { + fn get_row_addr(&self) -> i64 { + self.row_addr + } + + fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8 { + (self.element_offset + (index * element_size) as i64) as *const u8 + } +} + +impl SparkUnsafeArray { + /// Creates a `SparkUnsafeArray` which points to the given address and size in bytes. + pub fn new(addr: i64, size: i32) -> Self { + // Read the number of elements from the first 8 bytes. + let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) }; + let num_elements = i64::from_le_bytes(slice.try_into().unwrap()); + + if num_elements < 0 { + panic!("Negative number of elements: {}", num_elements); + } + + if num_elements > i32::MAX as i64 { + panic!("Number of elements should <= i32::MAX: {}", num_elements); + } + + Self { + row_addr: addr, + row_size: size, + num_elements: num_elements as usize, + element_offset: addr + Self::get_header_portion_in_bytes(num_elements), + } + } + + pub(crate) fn get_num_elements(&self) -> usize { + self.num_elements + } + + /// Returns the size of array header in bytes. + #[inline] + const fn get_header_portion_in_bytes(num_fields: i64) -> i64 { + 8 + ((num_fields + 63) / 64) * 8 + } + + /// Returns true if the null bit at the given index of the array is set. + #[inline] + pub(crate) fn is_null_at(&self, index: usize) -> bool { + unsafe { + let mask: i64 = 1i64 << (index & 0x3f); + let word_offset = (self.row_addr + 8 + (((index >> 6) as i64) << 3)) as *const i64; + let word: i64 = *word_offset; + (word & mask) != 0 + } + } +} + +/// A macro defines a function that appends the given list stored in `SparkUnsafeArray` into +/// `ListBuilder`. +macro_rules! define_append_element { + ($func:ident, $builder_type:ty, $accessor:expr) => { + #[allow(clippy::redundant_closure_call)] + fn $func( + list_builder: &mut ListBuilder<$builder_type>, + list: &SparkUnsafeArray, + idx: usize, + ) { + let element_builder: &mut $builder_type = list_builder.values(); + let is_null = list.is_null_at(idx); + + if is_null { + element_builder.append_null(); + } else { + $accessor(element_builder, list, idx); + } + } + }; +} + +define_append_element!( + append_boolean_element, + BooleanBuilder, + |builder: &mut BooleanBuilder, list: &SparkUnsafeArray, idx: usize| builder + .append_value(list.get_boolean(idx)) +); +define_append_element!( + append_int8_element, + Int8Builder, + |builder: &mut Int8Builder, list: &SparkUnsafeArray, idx: usize| builder + .append_value(list.get_byte(idx)) +); +define_append_element!( + append_int16_element, + Int16Builder, + |builder: &mut Int16Builder, list: &SparkUnsafeArray, idx: usize| builder + .append_value(list.get_short(idx)) +); +define_append_element!( + append_int32_element, + Int32Builder, + |builder: &mut Int32Builder, list: &SparkUnsafeArray, idx: usize| builder + .append_value(list.get_int(idx)) +); +define_append_element!( + append_int64_element, + Int64Builder, + |builder: &mut Int64Builder, list: &SparkUnsafeArray, idx: usize| builder + .append_value(list.get_long(idx)) +); +define_append_element!( + append_float32_element, + Float32Builder, + |builder: &mut Float32Builder, list: &SparkUnsafeArray, idx: usize| builder + .append_value(list.get_float(idx)) +); +define_append_element!( + append_float64_element, + Float64Builder, + |builder: &mut Float64Builder, list: &SparkUnsafeArray, idx: usize| builder + .append_value(list.get_double(idx)) +); +define_append_element!( + append_date32_element, + Date32Builder, + |builder: &mut Date32Builder, list: &SparkUnsafeArray, idx: usize| builder + .append_value(list.get_date(idx)) +); +define_append_element!( + append_timestamp_element, + TimestampMicrosecondBuilder, + |builder: &mut TimestampMicrosecondBuilder, list: &SparkUnsafeArray, idx: usize| builder + .append_value(list.get_timestamp(idx)) +); +define_append_element!( + append_binary_element, + BinaryBuilder, + |builder: &mut BinaryBuilder, list: &SparkUnsafeArray, idx: usize| builder + .append_value(list.get_binary(idx)) +); +define_append_element!( + append_string_element, + StringBuilder, + |builder: &mut StringBuilder, list: &SparkUnsafeArray, idx: usize| builder + .append_value(list.get_string(idx)) +); + +/// Appending the given list stored in `SparkUnsafeArray` into `ListBuilder`. +/// `element_dt` is the data type of the list element. `list_builder` is the list builder. +/// `list` is the list stored in `SparkUnsafeArray`. +pub fn append_list_element( + element_dt: &DataType, + list_builder: &mut ListBuilder, + list: &SparkUnsafeArray, +) -> Result<(), CometError> { + for idx in 0..list.get_num_elements() { + match element_dt { + DataType::Boolean => append_boolean_element( + list_builder + .as_any_mut() + .downcast_mut::>() + .unwrap(), + list, + idx, + ), + DataType::Int8 => append_int8_element( + list_builder + .as_any_mut() + .downcast_mut::>() + .unwrap(), + list, + idx, + ), + DataType::Int16 => append_int16_element( + list_builder + .as_any_mut() + .downcast_mut::>() + .unwrap(), + list, + idx, + ), + DataType::Int32 => append_int32_element( + list_builder + .as_any_mut() + .downcast_mut::>() + .unwrap(), + list, + idx, + ), + DataType::Int64 => append_int64_element( + list_builder + .as_any_mut() + .downcast_mut::>() + .unwrap(), + list, + idx, + ), + DataType::Float32 => append_float32_element( + list_builder + .as_any_mut() + .downcast_mut::>() + .unwrap(), + list, + idx, + ), + DataType::Float64 => append_float64_element( + list_builder + .as_any_mut() + .downcast_mut::>() + .unwrap(), + list, + idx, + ), + DataType::Date32 => append_date32_element( + list_builder + .as_any_mut() + .downcast_mut::>() + .unwrap(), + list, + idx, + ), + DataType::Timestamp(TimeUnit::Microsecond, _) => append_timestamp_element( + list_builder + .as_any_mut() + .downcast_mut::>() + .unwrap(), + list, + idx, + ), + DataType::Binary => append_binary_element( + list_builder + .as_any_mut() + .downcast_mut::>() + .unwrap(), + list, + idx, + ), + DataType::Utf8 => append_string_element( + list_builder + .as_any_mut() + .downcast_mut::>() + .unwrap(), + list, + idx, + ), + DataType::Decimal128(p, _) => { + let element_builder: &mut Decimal128Builder = list_builder + .values() + .as_any_mut() + .downcast_mut::() + .unwrap(); + let is_null = list.is_null_at(idx); + + if is_null { + element_builder.append_null(); + } else { + element_builder.append_value(list.get_decimal(idx, *p)) + } + } + // TODO: support nested list + // If the element is a list, we need to get the nested list builder by + // `list_builder.values()` and downcast to correct type, i.e., ListBuilder. + // But we don't know the type `U` so we cannot downcast to correct type + // and recursively call `append_list_element`. Later once we upgrade to + // latest Arrow, the `T` of `ListBuilder` could be `Box` + // which erase the deep type of the builder. + /* + DataType::List(field) => { + let element_builder: &mut ListBuilder<_> = list_builder + .values() + .as_any_mut() + .downcast_mut::>() + .unwrap(); + let is_null = list.is_null_at(idx); + + if is_null { + element_builder.append_null(); + } else { + append_list_element(field.data_type(), element_builder, list); + } + } + */ + DataType::Struct(fields) => { + let element_builder: &mut StructBuilder = list_builder + .values() + .as_any_mut() + .downcast_mut::() + .unwrap(); + let is_null = list.is_null_at(idx); + + if is_null { + element_builder.append_null(); + } else { + let nested_row = list.get_struct(idx, fields.len()); + element_builder.append(true); + for (field_idx, field) in fields.into_iter().enumerate() { + append_field(field.data_type(), element_builder, &nested_row, field_idx); + } + } + } + _ => { + return Err(CometError::Internal(format!( + "Unsupported data type in list element: {:?}", + element_dt + ))) + } + } + } + list_builder.append(true); + + Ok(()) +} diff --git a/core/src/execution/shuffle/mod.rs b/core/src/execution/shuffle/mod.rs new file mode 100644 index 000000000..0b11597d0 --- /dev/null +++ b/core/src/execution/shuffle/mod.rs @@ -0,0 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod list; +pub mod row; diff --git a/core/src/execution/shuffle/row.rs b/core/src/execution/shuffle/row.rs new file mode 100644 index 000000000..110b57827 --- /dev/null +++ b/core/src/execution/shuffle/row.rs @@ -0,0 +1,933 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utils for supporting native sort-based columnar shuffle. + +use crate::{ + errors::CometError, + execution::{ + datafusion::shuffle_writer::{compute_checksum, write_ipc_compressed, ChecksumAlgorithm}, + shuffle::list::{append_list_element, SparkUnsafeArray}, + utils::bytes_to_i128, + }, +}; +use arrow::compute::cast; +use arrow_array::{ + builder::{ + ArrayBuilder, BinaryBuilder, BinaryDictionaryBuilder, BooleanBuilder, Date32Builder, + Decimal128Builder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, + Int64Builder, Int8Builder, ListBuilder, StringBuilder, StringDictionaryBuilder, + StructBuilder, TimestampMicrosecondBuilder, + }, + types::Int32Type, + Array, ArrayRef, RecordBatch, +}; +use arrow_schema::{DataType, Field, Schema, TimeUnit}; +use jni::sys::{jint, jlong}; +use std::{ + fs::OpenOptions, + io::{Cursor, Seek, SeekFrom, Write}, + str::from_utf8, + sync::Arc, +}; + +const WORD_SIZE: i64 = 8; +const MAX_LONG_DIGITS: u8 = 18; +const LIST_BUILDER_CAPACITY: usize = 100; + +/// A common trait for Spark Unsafe classes that can be used to access the underlying data, +/// e.g., `UnsafeRow` and `UnsafeArray`. This defines a set of methods that can be used to +/// access the underlying data with index. +pub trait SparkUnsafeObject { + /// Returns the address of the row. + fn get_row_addr(&self) -> i64; + + /// Returns the offset of the element at the given index. + fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8; + + /// Returns the offset and length of the element at the given index. + #[inline] + fn get_offset_and_len(&self, index: usize) -> (i32, i32) { + let offset_and_size = self.get_long(index); + let offset = (offset_and_size >> 32) as i32; + let len = offset_and_size as i32; + (offset, len) + } + + /// Returns boolean value at the given index of the object. + fn get_boolean(&self, index: usize) -> bool { + let addr = self.get_element_offset(index, 1); + unsafe { *addr != 0 } + } + + /// Returns byte value at the given index of the object. + fn get_byte(&self, index: usize) -> i8 { + let addr = self.get_element_offset(index, 1); + let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 1) }; + i8::from_le_bytes(slice.try_into().unwrap()) + } + + /// Returns short value at the given index of the object. + fn get_short(&self, index: usize) -> i16 { + let addr = self.get_element_offset(index, 2); + let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 2) }; + i16::from_le_bytes(slice.try_into().unwrap()) + } + + /// Returns integer value at the given index of the object. + fn get_int(&self, index: usize) -> i32 { + let addr = self.get_element_offset(index, 4); + let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) }; + i32::from_le_bytes(slice.try_into().unwrap()) + } + + /// Returns long value at the given index of the object. + fn get_long(&self, index: usize) -> i64 { + let addr = self.get_element_offset(index, 8); + let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) }; + i64::from_le_bytes(slice.try_into().unwrap()) + } + + /// Returns float value at the given index of the object. + fn get_float(&self, index: usize) -> f32 { + let addr = self.get_element_offset(index, 4); + let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) }; + f32::from_le_bytes(slice.try_into().unwrap()) + } + + /// Returns double value at the given index of the object. + fn get_double(&self, index: usize) -> f64 { + let addr = self.get_element_offset(index, 8); + let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) }; + f64::from_le_bytes(slice.try_into().unwrap()) + } + + /// Returns string value at the given index of the object. + fn get_string(&self, index: usize) -> &str { + let (offset, len) = self.get_offset_and_len(index); + let addr = self.get_row_addr() + offset as i64; + let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }; + + from_utf8(slice).unwrap() + } + + /// Returns binary value at the given index of the object. + fn get_binary(&self, index: usize) -> &[u8] { + let (offset, len) = self.get_offset_and_len(index); + let addr = self.get_row_addr() + offset as i64; + unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) } + } + + /// Returns date value at the given index of the object. + fn get_date(&self, index: usize) -> i32 { + let addr = self.get_element_offset(index, 4); + let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) }; + i32::from_le_bytes(slice.try_into().unwrap()) + } + + /// Returns timestamp value at the given index of the object. + fn get_timestamp(&self, index: usize) -> i64 { + let addr = self.get_element_offset(index, 8); + let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) }; + i64::from_le_bytes(slice.try_into().unwrap()) + } + + /// Returns decimal value at the given index of the object. + fn get_decimal(&self, index: usize, precision: u8) -> i128 { + if precision <= MAX_LONG_DIGITS { + self.get_long(index) as i128 + } else { + let slice = self.get_binary(index); + bytes_to_i128(slice) + } + } + + /// Returns struct value at the given index of the object. + fn get_struct(&self, index: usize, num_fields: usize) -> SparkUnsafeRow { + let (offset, len) = self.get_offset_and_len(index); + let mut row = SparkUnsafeRow::new_with_num_fields(num_fields); + row.point_to(self.get_row_addr() + offset as i64, len); + + row + } + + /// Returns array value at the given index of the object. + fn get_array(&self, index: usize) -> SparkUnsafeArray { + let (offset, len) = self.get_offset_and_len(index); + SparkUnsafeArray::new(self.get_row_addr() + offset as i64, len) + } +} + +pub struct SparkUnsafeRow { + row_addr: i64, + row_size: i32, + row_bitset_width: i64, +} + +impl SparkUnsafeObject for SparkUnsafeRow { + fn get_row_addr(&self) -> i64 { + self.row_addr + } + + fn get_element_offset(&self, index: usize, _: usize) -> *const u8 { + (self.row_addr + self.row_bitset_width + (index * 8) as i64) as *const u8 + } +} + +impl SparkUnsafeRow { + fn new(schema: &Vec) -> Self { + Self { + row_addr: -1, + row_size: -1, + row_bitset_width: Self::get_row_bitset_width(schema.len()) as i64, + } + } + + /// Calculate the width of the bitset for the row in bytes. + /// The logic is from Spark `UnsafeRow.calculateBitSetWidthInBytes`. + #[inline] + pub const fn get_row_bitset_width(num_fields: usize) -> usize { + ((num_fields + 63) / 64) * 8 + } + + pub fn new_with_num_fields(num_fields: usize) -> Self { + Self { + row_addr: -1, + row_size: -1, + row_bitset_width: Self::get_row_bitset_width(num_fields) as i64, + } + } + + /// Points the row to the given slice. + pub fn point_to_slice(&mut self, slice: &[u8]) { + self.row_addr = slice.as_ptr() as i64; + self.row_size = slice.len() as i32; + } + + /// Points the row to the given address with specified row size. + fn point_to(&mut self, row_addr: i64, row_size: i32) { + self.row_addr = row_addr; + self.row_size = row_size; + } + + pub fn get_row_size(&self) -> i32 { + self.row_size + } + + /// Returns true if the null bit at the given index of the row is set. + #[inline] + pub(crate) fn is_null_at(&self, index: usize) -> bool { + unsafe { + let mask: i64 = 1i64 << (index & 0x3f); + let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *const i64; + let word: i64 = *word_offset; + (word & mask) != 0 + } + } + + /// Unsets the null bit at the given index of the row, i.e., set the bit to 0 (not null). + pub fn set_not_null_at(&mut self, index: usize) { + unsafe { + let mask: i64 = 1i64 << (index & 0x3f); + let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *mut i64; + let word: i64 = *word_offset; + *word_offset = word & !mask; + } + } +} + +macro_rules! downcast_builder { + ($builder_type:ty, $builder:expr) => { + $builder.into_box_any().downcast::<$builder_type>().unwrap() + }; +} + +/// Appends field of row to the given struct builder. `dt` is the data type of the field. +/// `struct_builder` is the struct builder of the row. `row` is the row that contains the field. +/// `idx` is the index of the field in the row. +#[allow(clippy::redundant_closure_call)] +pub(crate) fn append_field( + dt: &DataType, + struct_builder: &mut StructBuilder, + row: &SparkUnsafeRow, + idx: usize, +) { + /// A macro for generating code of appending value into field builder of Arrow struct builder. + macro_rules! append_field_to_builder { + ($builder_type:ty, $accessor:expr) => {{ + let field_builder = struct_builder.field_builder::<$builder_type>(idx).unwrap(); + let is_null = row.is_null_at(idx); + + if is_null { + field_builder.append_null(); + } else { + $accessor(field_builder); + } + }}; + } + + /// A macro for generating code of appending value into list field builder of Arrow struct + /// builder. + macro_rules! append_list_field_to_builder { + ($builder_type:ty, $element_dt:expr) => {{ + let field_builder = struct_builder + .field_builder::>(idx) + .unwrap(); + let is_null = row.is_null_at(idx); + + if is_null { + field_builder.append_null(); + } else { + append_list_element::<$builder_type>( + $element_dt, + field_builder, + &row.get_array(idx), + ) + .unwrap() + } + }}; + } + + match dt { + DataType::Boolean => { + append_field_to_builder!(BooleanBuilder, |builder: &mut BooleanBuilder| builder + .append_value(row.get_boolean(idx))); + } + DataType::Int8 => { + append_field_to_builder!(Int8Builder, |builder: &mut Int8Builder| builder + .append_value(row.get_byte(idx))); + } + DataType::Int16 => { + append_field_to_builder!(Int16Builder, |builder: &mut Int16Builder| builder + .append_value(row.get_short(idx))); + } + DataType::Int32 => { + append_field_to_builder!(Int32Builder, |builder: &mut Int32Builder| builder + .append_value(row.get_int(idx))); + } + DataType::Int64 => { + append_field_to_builder!(Int64Builder, |builder: &mut Int64Builder| builder + .append_value(row.get_long(idx))); + } + DataType::Float32 => { + append_field_to_builder!(Float32Builder, |builder: &mut Float32Builder| builder + .append_value(row.get_float(idx))); + } + DataType::Float64 => { + append_field_to_builder!(Float64Builder, |builder: &mut Float64Builder| builder + .append_value(row.get_double(idx))); + } + DataType::Date32 => { + append_field_to_builder!(Date32Builder, |builder: &mut Date32Builder| builder + .append_value(row.get_date(idx))); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + append_field_to_builder!( + TimestampMicrosecondBuilder, + |builder: &mut TimestampMicrosecondBuilder| builder + .append_value(row.get_timestamp(idx)) + ); + } + DataType::Binary => { + append_field_to_builder!(BinaryBuilder, |builder: &mut BinaryBuilder| builder + .append_value(row.get_binary(idx))); + } + DataType::Utf8 => { + append_field_to_builder!(StringBuilder, |builder: &mut StringBuilder| builder + .append_value(row.get_string(idx))); + } + DataType::Decimal128(p, _) => { + append_field_to_builder!(Decimal128Builder, |builder: &mut Decimal128Builder| builder + .append_value(row.get_decimal(idx, *p))); + } + DataType::Struct(fields) => { + append_field_to_builder!(StructBuilder, |builder: &mut StructBuilder| { + let nested_row = row.get_struct(idx, fields.len()); + builder.append(true); + for (field_idx, field) in fields.into_iter().enumerate() { + append_field(field.data_type(), builder, &nested_row, field_idx); + } + }); + } + DataType::List(field) => match field.data_type() { + DataType::Boolean => { + append_list_field_to_builder!(BooleanBuilder, field.data_type()); + } + DataType::Int8 => { + append_list_field_to_builder!(Int8Builder, field.data_type()); + } + DataType::Int16 => { + append_list_field_to_builder!(Int16Builder, field.data_type()); + } + DataType::Int32 => { + append_list_field_to_builder!(Int32Builder, field.data_type()); + } + DataType::Int64 => { + append_list_field_to_builder!(Int64Builder, field.data_type()); + } + DataType::Float32 => { + append_list_field_to_builder!(Float32Builder, field.data_type()); + } + DataType::Float64 => { + append_list_field_to_builder!(Float64Builder, field.data_type()); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + append_list_field_to_builder!(TimestampMicrosecondBuilder, field.data_type()); + } + DataType::Date32 => { + append_list_field_to_builder!(Date32Builder, field.data_type()); + } + DataType::Binary => { + append_list_field_to_builder!(BinaryBuilder, field.data_type()); + } + DataType::Utf8 => { + append_list_field_to_builder!(StringBuilder, field.data_type()); + } + DataType::Struct(_) => { + append_list_field_to_builder!(StructBuilder, field.data_type()); + } + DataType::Decimal128(_, _) => { + append_list_field_to_builder!(Decimal128Builder, field.data_type()); + } + _ => unreachable!("Unsupported data type of struct field: {:?}", dt), + }, + _ => { + unreachable!("Unsupported data type of struct field: {:?}", dt) + } + } +} + +/// Appends column of rows to the given array builder. +#[allow(clippy::redundant_closure_call, clippy::too_many_arguments)] +pub(crate) fn append_columns( + row_addresses_ptr: *mut jlong, + row_sizes_ptr: *mut jint, + row_start: usize, + row_end: usize, + schema: &Vec, + column_idx: usize, + builder: &mut Box, + prefer_dictionary_ratio: f64, +) { + /// A macro for generating code of appending values into Arrow array builders. + macro_rules! append_column_to_builder { + ($builder_type:ty, $accessor:expr) => {{ + let builder = builder + .as_any_mut() + .downcast_mut::<$builder_type>() + .unwrap(); + let mut row = SparkUnsafeRow::new(schema); + + for i in row_start..row_end { + let row_addr = unsafe { *row_addresses_ptr.add(i) }; + let row_size = unsafe { *row_sizes_ptr.add(i) }; + row.point_to(row_addr, row_size); + + let is_null = row.is_null_at(column_idx); + + if is_null { + builder.append_null(); + } else { + $accessor(builder, &row, column_idx); + } + } + }}; + } + + /// A macro for generating code of appending values into Arrow `ListBuilder`. + macro_rules! append_column_to_list_builder { + ($builder_type:ty, $element_dt:expr) => {{ + let builder = builder + .as_any_mut() + .downcast_mut::>() + .unwrap(); + let mut row = SparkUnsafeRow::new(schema); + + for i in row_start..row_end { + let row_addr = unsafe { *row_addresses_ptr.add(i) }; + let row_size = unsafe { *row_sizes_ptr.add(i) }; + row.point_to(row_addr, row_size); + + let is_null = row.is_null_at(column_idx); + + if is_null { + builder.append_null(); + } else { + append_list_element::<$builder_type>( + $element_dt, + builder, + &row.get_array(column_idx), + ) + .unwrap() + } + } + }}; + } + + let dt = &schema[column_idx]; + + match dt { + DataType::Boolean => { + append_column_to_builder!( + BooleanBuilder, + |builder: &mut BooleanBuilder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_boolean(idx)) + ); + } + DataType::Int8 => { + append_column_to_builder!( + Int8Builder, + |builder: &mut Int8Builder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_byte(idx)) + ); + } + DataType::Int16 => { + append_column_to_builder!( + Int16Builder, + |builder: &mut Int16Builder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_short(idx)) + ); + } + DataType::Int32 => { + append_column_to_builder!( + Int32Builder, + |builder: &mut Int32Builder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_int(idx)) + ); + } + DataType::Int64 => { + append_column_to_builder!( + Int64Builder, + |builder: &mut Int64Builder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_long(idx)) + ); + } + DataType::Float32 => { + append_column_to_builder!( + Float32Builder, + |builder: &mut Float32Builder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_float(idx)) + ); + } + DataType::Float64 => { + append_column_to_builder!( + Float64Builder, + |builder: &mut Float64Builder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_double(idx)) + ); + } + DataType::Decimal128(p, _) => { + append_column_to_builder!( + Decimal128Builder, + |builder: &mut Decimal128Builder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_decimal(idx, *p)) + ); + } + DataType::Utf8 => { + if prefer_dictionary_ratio > 1.0 { + append_column_to_builder!( + StringDictionaryBuilder, + |builder: &mut StringDictionaryBuilder, + row: &SparkUnsafeRow, + idx| builder.append_value(row.get_string(idx)) + ); + } else { + append_column_to_builder!( + StringBuilder, + |builder: &mut StringBuilder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_string(idx)) + ); + } + } + DataType::Binary => { + if prefer_dictionary_ratio > 1.0 { + append_column_to_builder!( + BinaryDictionaryBuilder, + |builder: &mut BinaryDictionaryBuilder, + row: &SparkUnsafeRow, + idx| builder.append_value(row.get_binary(idx)) + ); + } else { + append_column_to_builder!( + BinaryBuilder, + |builder: &mut BinaryBuilder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_binary(idx)) + ); + } + } + DataType::Date32 => { + append_column_to_builder!( + Date32Builder, + |builder: &mut Date32Builder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_date(idx)) + ); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + append_column_to_builder!( + TimestampMicrosecondBuilder, + |builder: &mut TimestampMicrosecondBuilder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_timestamp(idx)) + ); + } + DataType::List(field) => match field.data_type() { + DataType::Boolean => { + append_column_to_list_builder!(BooleanBuilder, field.data_type()); + } + DataType::Int8 => { + append_column_to_list_builder!(Int8Builder, field.data_type()); + } + DataType::Int16 => { + append_column_to_list_builder!(Int16Builder, field.data_type()); + } + DataType::Int32 => { + append_column_to_list_builder!(Int32Builder, field.data_type()); + } + DataType::Int64 => { + append_column_to_list_builder!(Int64Builder, field.data_type()); + } + DataType::Float32 => { + append_column_to_list_builder!(Float32Builder, field.data_type()); + } + DataType::Float64 => { + append_column_to_list_builder!(Float64Builder, field.data_type()); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + append_column_to_list_builder!(TimestampMicrosecondBuilder, field.data_type()); + } + DataType::Date32 => { + append_column_to_list_builder!(Date32Builder, field.data_type()); + } + DataType::Binary => { + append_column_to_list_builder!(BinaryBuilder, field.data_type()); + } + DataType::Utf8 => { + append_column_to_list_builder!(StringBuilder, field.data_type()); + } + DataType::Struct(_) => { + append_column_to_list_builder!(StructBuilder, field.data_type()); + } + DataType::Decimal128(_, _) => { + append_column_to_list_builder!(Decimal128Builder, field.data_type()); + } + _ => unreachable!("Unsupported data type of list element: {:?}", dt), + }, + DataType::Struct(fields) => { + append_column_to_builder!( + StructBuilder, + |builder: &mut StructBuilder, row: &SparkUnsafeRow, idx| { + let nested_row = row.get_struct(idx, fields.len()); + builder.append(true); + for (idx, field) in fields.into_iter().enumerate() { + append_field(field.data_type(), builder, &nested_row, idx); + } + } + ); + } + _ => { + unreachable!("Unsupported data type of column: {:?}", dt) + } + } +} + +fn make_builders( + dt: &DataType, + row_num: usize, + prefer_dictionary_ratio: f64, +) -> Result, CometError> { + let builder: Box = match dt { + DataType::Boolean => Box::new(BooleanBuilder::with_capacity(row_num)), + DataType::Int8 => Box::new(Int8Builder::with_capacity(row_num)), + DataType::Int16 => Box::new(Int16Builder::with_capacity(row_num)), + DataType::Int32 => Box::new(Int32Builder::with_capacity(row_num)), + DataType::Int64 => Box::new(Int64Builder::with_capacity(row_num)), + DataType::Float32 => Box::new(Float32Builder::with_capacity(row_num)), + DataType::Float64 => Box::new(Float64Builder::with_capacity(row_num)), + DataType::Decimal128(_, _) => { + Box::new(Decimal128Builder::with_capacity(row_num).with_data_type(dt.clone())) + } + DataType::Utf8 => { + if prefer_dictionary_ratio > 1.0 { + Box::new(StringDictionaryBuilder::::with_capacity( + row_num / 2, + row_num, + 1024, + )) + } else { + Box::new(StringBuilder::with_capacity(row_num, 1024)) + } + } + DataType::Binary => { + if prefer_dictionary_ratio > 1.0 { + Box::new(BinaryDictionaryBuilder::::with_capacity( + row_num / 2, + row_num, + 1024, + )) + } else { + Box::new(BinaryBuilder::with_capacity(row_num, 1024)) + } + } + DataType::Date32 => Box::new(Date32Builder::with_capacity(row_num)), + DataType::Timestamp(TimeUnit::Microsecond, _) => { + Box::new(TimestampMicrosecondBuilder::with_capacity(row_num).with_data_type(dt.clone())) + } + DataType::List(field) => { + // Disable dictionary encoding for array element + let value_builder = make_builders(field.data_type(), LIST_BUILDER_CAPACITY, 1.0)?; + match field.data_type() { + DataType::Boolean => { + let builder = downcast_builder!(BooleanBuilder, value_builder); + Box::new(ListBuilder::new(*builder)) + } + DataType::Int8 => { + let builder = downcast_builder!(Int8Builder, value_builder); + Box::new(ListBuilder::new(*builder)) + } + DataType::Int16 => { + let builder = downcast_builder!(Int16Builder, value_builder); + Box::new(ListBuilder::new(*builder)) + } + DataType::Int32 => { + let builder = downcast_builder!(Int32Builder, value_builder); + Box::new(ListBuilder::new(*builder)) + } + DataType::Int64 => { + let builder = downcast_builder!(Int64Builder, value_builder); + Box::new(ListBuilder::new(*builder)) + } + DataType::Float32 => { + let builder = downcast_builder!(Float32Builder, value_builder); + Box::new(ListBuilder::new(*builder)) + } + DataType::Float64 => { + let builder = downcast_builder!(Float64Builder, value_builder); + Box::new(ListBuilder::new(*builder)) + } + DataType::Decimal128(_, _) => { + let builder = downcast_builder!(Decimal128Builder, value_builder); + Box::new(ListBuilder::new(*builder)) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + let builder = downcast_builder!(TimestampMicrosecondBuilder, value_builder); + Box::new(ListBuilder::new(*builder)) + } + DataType::Date32 => { + let builder = downcast_builder!(Date32Builder, value_builder); + Box::new(ListBuilder::new(*builder)) + } + DataType::Binary => { + let builder = downcast_builder!(BinaryBuilder, value_builder); + Box::new(ListBuilder::new(*builder)) + } + DataType::Utf8 => { + let builder = downcast_builder!(StringBuilder, value_builder); + Box::new(ListBuilder::new(*builder)) + } + DataType::Struct(_) => { + let builder = downcast_builder!(StructBuilder, value_builder); + Box::new(ListBuilder::new(*builder)) + } + // TODO: nested list is not supported. Due to the design of `ListBuilder`, it has + // a `T: ArrayBuilder` as type parameter. It makes hard to construct an arbitrarily + // nested `ListBuilder`. + DataType::List(_) => { + return Err(CometError::Internal( + "list of list is not supported type".to_string(), + )) + } + _ => { + return Err(CometError::Internal(format!( + "Unsupported type: {:?}", + field.data_type() + ))) + } + } + } + DataType::Struct(fields) => { + let field_builders = fields + .iter() + // Disable dictionary encoding for struct fields + .map(|field| make_builders(field.data_type(), row_num, 1.0)) + .collect::, _>>()?; + + Box::new(StructBuilder::new(fields.clone(), field_builders)) + } + _ => return Err(CometError::Internal(format!("Unsupported type: {:?}", dt))), + }; + + Ok(builder) +} + +/// Processes a sorted row partition and writes the result to the given output path. +#[allow(clippy::too_many_arguments)] +pub fn process_sorted_row_partition( + row_num: usize, + row_addresses_ptr: *mut jlong, + row_sizes_ptr: *mut jint, + schema: &Vec, + output_path: String, + prefer_dictionary_ratio: f64, + checksum_enabled: bool, + checksum_algo: i32, + current_checksum: Option, +) -> Result<(i64, Option), CometError> { + let mut data_builders: Vec> = vec![]; + schema.iter().try_for_each(|dt| { + make_builders(dt, row_num, prefer_dictionary_ratio) + .map(|builder| data_builders.push(builder))?; + Ok::<(), CometError>(()) + })?; + + // Appends rows to the array builders. + let mut row_start: usize = 0; + // TODO: We can tune this parameter automatically based on row size and cache size. + let row_step = 10; + while row_start < row_num { + let row_end = std::cmp::min(row_start + row_step, row_num); + + // For each column, iterating over rows and appending values to corresponding array builder. + for (idx, builder) in data_builders.iter_mut().enumerate() { + append_columns( + row_addresses_ptr, + row_sizes_ptr, + row_start, + row_end, + schema, + idx, + builder, + prefer_dictionary_ratio, + ); + } + + row_start = row_end; + } + + // Writes a record batch generated from the array builders to the output file. + let array_refs: Result, _> = data_builders + .iter_mut() + .zip(schema.iter()) + .map(|(builder, datatype)| builder_to_array(builder, datatype, prefer_dictionary_ratio)) + .collect(); + let batch = make_batch(array_refs?); + + let mut frozen: Vec = vec![]; + let mut cursor = Cursor::new(&mut frozen); + cursor.seek(SeekFrom::End(0))?; + let written = write_ipc_compressed(&batch, &mut cursor)?; + + let checksum = if checksum_enabled { + let checksum = match checksum_algo { + 0 => ChecksumAlgorithm::CRC32(current_checksum), + 1 => ChecksumAlgorithm::Adler32(current_checksum), + _ => { + return Err(CometError::Internal( + "Unsupported checksum algorithm".to_string(), + )) + } + }; + + match compute_checksum(&mut cursor, &checksum)? { + ChecksumAlgorithm::CRC32(checksum) => checksum, + ChecksumAlgorithm::Adler32(checksum) => checksum, + } + } else { + None + }; + + let mut output_data = OpenOptions::new() + .create(true) + .append(true) + .open(output_path)?; + + output_data.write_all(&frozen)?; + + Ok((written as i64, checksum)) +} + +fn builder_to_array( + builder: &mut Box, + datatype: &DataType, + prefer_dictionary_ratio: f64, +) -> Result { + match datatype { + // We don't have redundant dictionary values which are not referenced by any key. + // So the reasonable ratio must be larger than 1.0. + DataType::Utf8 if prefer_dictionary_ratio > 1.0 => { + let builder = builder + .as_any_mut() + .downcast_mut::>() + .unwrap(); + + let dict_array = builder.finish(); + let num_keys = dict_array.keys().len(); + let num_values = dict_array.values().len(); + + if num_keys as f64 > num_values as f64 * prefer_dictionary_ratio { + // The number of keys in the dictionary is less than a ratio of the number of + // values. The dictionary is efficient, so we return it directly. + Ok(Arc::new(dict_array)) + } else { + // If the dictionary is not efficient, we convert it to a plain string array. + Ok(cast(&dict_array, &DataType::Utf8)?) + } + } + DataType::Binary if prefer_dictionary_ratio > 1.0 => { + let builder = builder + .as_any_mut() + .downcast_mut::>() + .unwrap(); + + let dict_array = builder.finish(); + let num_keys = dict_array.keys().len(); + let num_values = dict_array.values().len(); + + if num_keys as f64 > num_values as f64 * prefer_dictionary_ratio { + // The number of keys in the dictionary is less than a ratio of the number of + // values. The dictionary is efficient, so we return it directly. + Ok(Arc::new(dict_array)) + } else { + // If the dictionary is not efficient, we convert it to a plain string array. + Ok(cast(&dict_array, &DataType::Binary)?) + } + } + _ => Ok(builder.finish()), + } +} + +fn make_batch(arrays: Vec) -> RecordBatch { + let mut dict_id = 0; + let fields = arrays + .iter() + .enumerate() + .map(|(i, array)| match array.data_type() { + DataType::Dictionary(_, _) => { + let field = Field::new_dict( + format!("c{}", i), + array.data_type().clone(), + true, + dict_id, + false, + ); + dict_id += 1; + field + } + _ => Field::new(format!("c{}", i), array.data_type().clone(), true), + }) + .collect::>(); + let schema = Arc::new(Schema::new(fields)); + RecordBatch::try_new(schema, arrays).unwrap() +} diff --git a/core/src/execution/sort.rs b/core/src/execution/sort.rs new file mode 100644 index 000000000..57c9932f8 --- /dev/null +++ b/core/src/execution/sort.rs @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{cmp, mem, ptr}; + +/// This is a copy of the `rdxsort-rs` crate, with the following changes: +/// - removed `Rdx` implementations for all types except for i64 which is the packed representation +/// of row addresses and partition ids from Spark. + +pub trait Rdx { + /// Sets the number of buckets used by the generic implementation. + fn cfg_nbuckets() -> usize; + + /// Sets the number of rounds scheduled by the generic implementation. + fn cfg_nrounds() -> usize; + + /// Returns the bucket, depending on the round. + /// + /// This should respect the radix, e.g.: + /// + /// - if the number of buckets is `2` and the type is an unsigned integer, then the result is + /// the bit starting with the least significant one. + /// - if the number of buckets is `8` and the type is an unsigned integer, then the result is + /// the byte starting with the least significant one. + /// + /// **Never** return a bucker greater or equal the number of buckets. See warning above! + fn get_bucket(&self, round: usize) -> usize; + + /// Describes the fact that the content of a bucket should be copied back in reverse order + /// after a certain round. + fn reverse(round: usize, bucket: usize) -> bool; +} + +const MASK_LONG_LOWER_40_BITS: u64 = (1u64 << 40) - 1; +const MASK_LONG_UPPER_24_BITS: u64 = !MASK_LONG_LOWER_40_BITS; + +/// `Rdx` implementation for particular i64 which represents a packed representation of row address +/// and partition id from Spark. +impl Rdx for i64 { + #[inline] + fn cfg_nbuckets() -> usize { + 16 + } + + #[inline] + fn cfg_nrounds() -> usize { + // Partition id is 3 bytes. Each byte has 2 rounds. + 6 + } + + #[inline] + fn get_bucket(&self, round: usize) -> usize { + let partition_id = (*self as u64 & MASK_LONG_UPPER_24_BITS) >> 40; + + let shift = round << 2; + ((partition_id >> shift) & 15u64) as usize + } + + #[inline] + fn reverse(_round: usize, _bucket: usize) -> bool { + false + } +} + +/// Radix Sort implementation for some type +pub trait RdxSort { + /// Execute Radix Sort, overwrites (unsorted) content of the type. + fn rdxsort(&mut self); +} + +#[inline] +fn helper_bucket(buckets_b: &mut [Vec], iter: I, cfg_nbuckets: usize, round: usize) +where + T: Rdx, + I: Iterator, +{ + for x in iter { + let b = x.get_bucket(round); + assert!( + b < cfg_nbuckets, + "Your Rdx implementation returns a bucket >= cfg_nbuckets()!" + ); + unsafe { + buckets_b.get_unchecked_mut(b).push(x); + } + } +} + +impl RdxSort for [T] +where + T: Rdx + Clone, +{ + fn rdxsort(&mut self) { + // config + let cfg_nbuckets = T::cfg_nbuckets(); + let cfg_nrounds = T::cfg_nrounds(); + + // early return + if cfg_nrounds == 0 { + return; + } + + let n = self.len(); + let presize = cmp::max(16, (n << 2) / cfg_nbuckets); // TODO: justify the presize value + let mut buckets_a: Vec> = Vec::with_capacity(cfg_nbuckets); + let mut buckets_b: Vec> = Vec::with_capacity(cfg_nbuckets); + for _ in 0..cfg_nbuckets { + buckets_a.push(Vec::with_capacity(presize)); + buckets_b.push(Vec::with_capacity(presize)); + } + + helper_bucket(&mut buckets_a, self.iter().cloned(), cfg_nbuckets, 0); + + for round in 1..cfg_nrounds { + for bucket in &mut buckets_b { + bucket.clear(); + } + for (i, bucket) in buckets_a.iter().enumerate() { + if T::reverse(round - 1, i) { + helper_bucket( + &mut buckets_b, + bucket.iter().rev().cloned(), + cfg_nbuckets, + round, + ); + } else { + helper_bucket(&mut buckets_b, bucket.iter().cloned(), cfg_nbuckets, round); + } + } + mem::swap(&mut buckets_a, &mut buckets_b); + } + + let mut pos = 0; + for (i, bucket) in buckets_a.iter_mut().enumerate() { + assert!( + pos + bucket.len() <= self.len(), + "bug: a buckets got oversized" + ); + + if T::reverse(cfg_nrounds - 1, i) { + for x in bucket.iter().rev().cloned() { + unsafe { + *self.get_unchecked_mut(pos) = x; + } + pos += 1; + } + } else { + unsafe { + ptr::copy_nonoverlapping( + bucket.as_ptr(), + self.get_unchecked_mut(pos), + bucket.len(), + ); + } + pos += bucket.len(); + } + } + + assert!(pos == self.len(), "bug: bucket size does not sum up"); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const MAXIMUM_PARTITION_ID: i32 = (1i32 << 24) - 1; + const MASK_LONG_LOWER_51_BITS: i64 = (1i64 << 51) - 1; + const MASK_LONG_UPPER_13_BITS: i64 = !MASK_LONG_LOWER_51_BITS; + const MASK_LONG_LOWER_27_BITS: i64 = (1i64 << 27) - 1; + + /// Copied from Spark class `PackedRecordPointer`. + fn pack_pointer(pointer: i64, partition_id: i32) -> i64 { + assert!(partition_id <= MAXIMUM_PARTITION_ID); + + let page_number = (pointer & MASK_LONG_UPPER_13_BITS) >> 24; + let compressed_address = page_number | (pointer & MASK_LONG_LOWER_27_BITS); + ((partition_id as i64) << 40) | compressed_address + } + + #[test] + fn test_rdxsort() { + let mut v = vec![ + pack_pointer(1, 0), + pack_pointer(2, 3), + pack_pointer(3, 2), + pack_pointer(4, 5), + pack_pointer(5, 0), + pack_pointer(6, 1), + pack_pointer(7, 3), + pack_pointer(8, 3), + ]; + v.rdxsort(); + + let expected = vec![ + pack_pointer(1, 0), + pack_pointer(5, 0), + pack_pointer(6, 1), + pack_pointer(3, 2), + pack_pointer(2, 3), + pack_pointer(7, 3), + pack_pointer(8, 3), + pack_pointer(4, 5), + ]; + + assert_eq!(v, expected); + } +} diff --git a/dev/ensure-jars-have-correct-contents.sh b/dev/ensure-jars-have-correct-contents.sh old mode 100644 new mode 100755 index d364ae1f4..1ab09a5f8 --- a/dev/ensure-jars-have-correct-contents.sh +++ b/dev/ensure-jars-have-correct-contents.sh @@ -70,6 +70,13 @@ allowed_expr+="|log4j2.properties" allowed_expr+="|comet-git-info.properties" # For some reason org/apache/spark/sql directory is also included, but with no content allowed_expr+="|^org/apache/spark/$" +# Some shuffle related classes are spark-private, e.g. TempShuffleBlockId, ShuffleWriteMetricsReporter, +# so these classes which use shuffle classes have to be in org/apache/spark. +allowed_expr+="|^org/apache/spark/shuffle/$" +allowed_expr+="|^org/apache/spark/shuffle/sort/$" +allowed_expr+="|^org/apache/spark/shuffle/sort/CometShuffleExternalSorter.*$" +allowed_expr+="|^org/apache/spark/shuffle/sort/RowPartition.class$" +allowed_expr+="|^org/apache/spark/shuffle/comet/.*$" allowed_expr+="|^org/apache/spark/sql/$" allowed_expr+="|^org/apache/spark/CometPlugin.class$" allowed_expr+="|^org/apache/spark/CometDriverPlugin.*$" diff --git a/spark/src/main/java/org/apache/spark/shuffle/comet/CometShuffleChecksumSupport.java b/spark/src/main/java/org/apache/spark/shuffle/comet/CometShuffleChecksumSupport.java new file mode 100644 index 000000000..03923d391 --- /dev/null +++ b/spark/src/main/java/org/apache/spark/shuffle/comet/CometShuffleChecksumSupport.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.shuffle.comet; + +import org.apache.spark.SparkConf; +import org.apache.spark.internal.config.package$; + +public interface CometShuffleChecksumSupport { + long[] EMPTY_CHECKSUM_VALUE = new long[0]; + + default long[] createPartitionChecksums(int numPartitions, SparkConf conf) { + if ((boolean) conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ENABLED())) { + long[] checksum = new long[numPartitions]; + + // Initialize checksums to Long.MIN_VALUE to indicate that they have not been computed. + for (int i = 0; i < numPartitions; i++) { + checksum[i] = Long.MIN_VALUE; + } + + return checksum; + } else { + return EMPTY_CHECKSUM_VALUE; + } + } + + default String getChecksumAlgorithm(SparkConf conf) { + if ((boolean) conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ENABLED())) { + return conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); + } else { + return null; + } + } +} diff --git a/spark/src/main/java/org/apache/spark/shuffle/comet/CometShuffleMemoryAllocator.java b/spark/src/main/java/org/apache/spark/shuffle/comet/CometShuffleMemoryAllocator.java new file mode 100644 index 000000000..e8fe170e6 --- /dev/null +++ b/spark/src/main/java/org/apache/spark/shuffle/comet/CometShuffleMemoryAllocator.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.shuffle.comet; + +import java.io.IOException; +import java.util.BitSet; + +import org.apache.spark.SparkConf; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.memory.SparkOutOfMemoryError; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.UnsafeMemoryAllocator; + +import org.apache.comet.CometSparkSessionExtensions$; + +/** + * A simple memory allocator used by `CometShuffleExternalSorter` to allocate memory blocks which + * store serialized rows. We don't rely on Spark memory allocator because we need to allocate + * off-heap memory no matter memory mode is on-heap or off-heap. This allocator is configured with + * fixed size of memory, and it will throw `SparkOutOfMemoryError` if the memory is not enough. + * + *

Some methods are copied from `org.apache.spark.unsafe.memory.TaskMemoryManager` with + * modifications. Most modifications are to remove the dependency on the configured memory mode. + */ +public final class CometShuffleMemoryAllocator extends MemoryConsumer { + private final UnsafeMemoryAllocator allocator = new UnsafeMemoryAllocator(); + + private final long pageSize; + private final long totalMemory; + private long allocatedMemory = 0L; + + /** The number of bits used to address the page table. */ + private static final int PAGE_NUMBER_BITS = 13; + /** The number of entries in the page table. */ + private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS; + + private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE]; + private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE); + + private static final int OFFSET_BITS = 51; + private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; + + private static CometShuffleMemoryAllocator INSTANCE; + + public static synchronized CometShuffleMemoryAllocator getInstance( + SparkConf conf, TaskMemoryManager taskMemoryManager, long pageSize) { + if (INSTANCE == null) { + INSTANCE = new CometShuffleMemoryAllocator(conf, taskMemoryManager, pageSize); + } + + return INSTANCE; + } + + CometShuffleMemoryAllocator(SparkConf conf, TaskMemoryManager taskMemoryManager, long pageSize) { + super(taskMemoryManager, pageSize, MemoryMode.OFF_HEAP); + this.pageSize = pageSize; + this.totalMemory = + CometSparkSessionExtensions$.MODULE$.getCometShuffleMemorySize(conf, SQLConf.get()); + } + + public synchronized long acquireMemory(long size) { + if (allocatedMemory >= totalMemory) { + throw new SparkOutOfMemoryError( + "Unable to acquire " + + size + + " bytes of memory, current usage " + + "is " + + allocatedMemory + + " bytes and max memory is " + + totalMemory + + " bytes"); + } + long allocationSize = Math.min(size, totalMemory - allocatedMemory); + allocatedMemory += allocationSize; + return allocationSize; + } + + public long spill(long l, MemoryConsumer memoryConsumer) throws IOException { + return 0; + } + + public synchronized LongArray allocateArray(long size) { + long required = size * 8L; + MemoryBlock page = allocate(required); + return new LongArray(page); + } + + public synchronized void freeArray(LongArray array) { + if (array == null) { + return; + } + free(array.memoryBlock()); + } + + public synchronized MemoryBlock allocatePage(long required) { + long size = Math.max(pageSize, required); + return allocate(size); + } + + private synchronized MemoryBlock allocate(long required) { + if (required > TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES) { + throw new TooLargePageException(required); + } + + long got = acquireMemory(required); + + if (got < required) { + allocatedMemory -= got; + + throw new SparkOutOfMemoryError( + "Unable to acquire " + + required + + " bytes of memory, got " + + got + + " bytes. Available: " + + (totalMemory - allocatedMemory)); + } + + int pageNumber = allocatedPages.nextClearBit(0); + if (pageNumber >= PAGE_TABLE_SIZE) { + allocatedMemory -= got; + + throw new IllegalStateException( + "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); + } + + MemoryBlock block = allocator.allocate(got); + + block.pageNumber = pageNumber; + pageTable[pageNumber] = block; + allocatedPages.set(pageNumber); + + return block; + } + + public synchronized void free(MemoryBlock block) { + if (block.pageNumber == MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) { + // Already freed block + return; + } + allocatedMemory -= block.size(); + + pageTable[block.pageNumber] = null; + allocatedPages.clear(block.pageNumber); + block.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; + + allocator.free(block); + } + + public synchronized long getAvailableMemory() { + return totalMemory - allocatedMemory; + } + + /** + * Returns the offset in the page for the given page plus base offset address. Note that this + * method assumes that the page number is valid. + */ + public long getOffsetInPage(long pagePlusOffsetAddress) { + long offsetInPage = decodeOffset(pagePlusOffsetAddress); + int pageNumber = TaskMemoryManager.decodePageNumber(pagePlusOffsetAddress); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + MemoryBlock page = pageTable[pageNumber]; + assert (page != null); + return page.getBaseOffset() + offsetInPage; + } + + public long decodeOffset(long pagePlusOffsetAddress) { + return pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS; + } + + public long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { + assert (pageNumber >= 0); + return ((long) pageNumber) << OFFSET_BITS | offsetInPage & MASK_LONG_LOWER_51_BITS; + } + + public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { + return encodePageNumberAndOffset(page.pageNumber, offsetInPage - page.getBaseOffset()); + } +} diff --git a/spark/src/main/java/org/apache/spark/shuffle/comet/TooLargePageException.java b/spark/src/main/java/org/apache/spark/shuffle/comet/TooLargePageException.java new file mode 100644 index 000000000..5643056fc --- /dev/null +++ b/spark/src/main/java/org/apache/spark/shuffle/comet/TooLargePageException.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.shuffle.comet; + +public class TooLargePageException extends RuntimeException { + TooLargePageException(long size) { + super("Cannot allocate a page of " + size + " bytes."); + } +} diff --git a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java new file mode 100644 index 000000000..aa806e2e8 --- /dev/null +++ b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java @@ -0,0 +1,627 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.File; +import java.io.IOException; +import java.util.LinkedList; +import java.util.concurrent.*; +import javax.annotation.Nullable; + +import scala.Tuple2; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.memory.SparkOutOfMemoryError; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.shuffle.comet.CometShuffleChecksumSupport; +import org.apache.spark.shuffle.comet.CometShuffleMemoryAllocator; +import org.apache.spark.shuffle.comet.TooLargePageException; +import org.apache.spark.sql.comet.execution.shuffle.CometUnsafeShuffleWriter; +import org.apache.spark.sql.comet.execution.shuffle.ShuffleThreadPool; +import org.apache.spark.sql.comet.execution.shuffle.SpillInfo; +import org.apache.spark.sql.comet.execution.shuffle.SpillWriter; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.TempShuffleBlockId; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.util.Utils; + +import org.apache.comet.CometConf$; +import org.apache.comet.Native; + +/** + * An external sorter that is specialized for sort-based shuffle. + * + *

Incoming records are appended to data pages. When all records have been inserted (or when the + * current thread's shuffle memory limit is reached), the in-memory records are sorted according to + * their partition ids using native sorter. The sorted records are then written to a single output + * file (or multiple files, if we've spilled). + * + *

Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its + * spill files. Instead, this merging is performed in {@link CometUnsafeShuffleWriter}, which uses a + * specialized merge procedure that avoids extra serialization/deserialization. + * + *

This sorter provides async spilling write mode. When spilling, it will submit a task to thread + * pool to write shuffle spilling file. After submitting the task, it will continue to buffer, sort + * incoming records and submit another spilling task once spilling threshold reached again or memory + * is not enough to buffer incoming records. Each spilling task will write a shuffle spilling file + * separately. After all records have been sorted and spilled, all spill files will be merged by + * {@link CometUnsafeShuffleWriter}. + */ +public final class CometShuffleExternalSorter implements CometShuffleChecksumSupport { + + private static final Logger logger = LoggerFactory.getLogger(CometShuffleExternalSorter.class); + + private final int numPartitions; + private final BlockManager blockManager; + private final TaskContext taskContext; + private final ShuffleWriteMetricsReporter writeMetrics; + + private final StructType schema; + + /** Force this sorter to spill when there are this many elements in memory. */ + private final int numElementsForSpillThreshold; + + // When this external sorter allocates memory of `sorterArray`, we need to keep its + // assigned initial size. After spilling, we will reset the array to its initial size. + // See `sorterArray` comment for more details. + private int initialSize; + + /** All sorters with memory pages used by the sorters. */ + private final ConcurrentLinkedQueue spillingSorters = new ConcurrentLinkedQueue<>(); + + private SpillSorter activeSpillSorter; + + private final LinkedList spills = new LinkedList<>(); + + /** Peak memory used by this sorter so far, in bytes. */ + private long peakMemoryUsedBytes; + + // Checksum calculator for each partition. Empty when shuffle checksum disabled. + private final long[] partitionChecksums; + + private final String checksumAlgorithm; + + // The memory allocator for this sorter. It is used to allocate/free memory pages for this sorter. + // Because we need to allocate off-heap memory regardless of configured Spark memory mode + // (on-heap/off-heap), we need a separate memory allocator. + private final CometShuffleMemoryAllocator allocator; + + /** Whether to write shuffle spilling file in async mode */ + private final boolean isAsync; + + /** Thread pool shared for async spilling write */ + private final ExecutorService threadPool; + + private final int threadNum; + + private ConcurrentLinkedQueue> asyncSpillTasks = new ConcurrentLinkedQueue<>(); + + private boolean spilling = false; + + private final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + private final double preferDictionaryRatio; + + public CometShuffleExternalSorter( + TaskMemoryManager memoryManager, + BlockManager blockManager, + TaskContext taskContext, + int initialSize, + int numPartitions, + SparkConf conf, + ShuffleWriteMetricsReporter writeMetrics, + StructType schema) { + this.allocator = + CometShuffleMemoryAllocator.getInstance( + conf, + memoryManager, + Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes())); + this.blockManager = blockManager; + this.taskContext = taskContext; + this.numPartitions = numPartitions; + this.schema = schema; + this.numElementsForSpillThreshold = + (int) CometConf$.MODULE$.COMET_EXEC_SHUFFLE_SPILL_THRESHOLD().get(); + this.writeMetrics = writeMetrics; + + this.peakMemoryUsedBytes = getMemoryUsage(); + this.partitionChecksums = createPartitionChecksums(numPartitions, conf); + this.checksumAlgorithm = getChecksumAlgorithm(conf); + + this.initialSize = initialSize; + + this.isAsync = (boolean) CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED().get(); + + if (isAsync) { + this.threadNum = (int) CometConf$.MODULE$.COMET_EXEC_SHUFFLE_ASYNC_THREAD_NUM().get(); + assert (this.threadNum > 0); + this.threadPool = ShuffleThreadPool.getThreadPool(); + } else { + this.threadNum = 0; + this.threadPool = null; + } + + this.activeSpillSorter = new SpillSorter(); + + this.preferDictionaryRatio = + (double) CometConf$.MODULE$.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO().get(); + } + + public long[] getChecksums() { + return partitionChecksums; + } + + /** Sort and spill the current records in response to memory pressure. */ + public void spill() throws IOException { + if (spilling || activeSpillSorter == null || activeSpillSorter.numRecords() == 0) { + return; + } + + // In async mode, if new in-memory sorter cannot allocate required array, it triggers spill + // here. This method will initiate new sorter following normal spill logic and casue stack + // overflow eventually. So we need to avoid triggering spilling again while spilling. But + // we cannot make this as "synchronized" because it will block the caller thread. + spilling = true; + + logger.info( + "Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spills.size(), + spills.size() > 1 ? " times" : " time"); + + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = spilledFileInfo._2(); + final TempShuffleBlockId blockId = spilledFileInfo._1(); + final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId); + + activeSpillSorter.setSpillInfo(spillInfo); + + if (isAsync) { + SpillSorter spillingSorter = activeSpillSorter; + Callable task = + () -> { + spillingSorter.writeSortedFileNative(false); + final long spillSize = spillingSorter.freeMemory(); + spillingSorter.freeArray(); + spillingSorters.remove(spillingSorter); + + // Reset the in-memory sorter's pointer array only after freeing up the memory pages + // holding the records. Otherwise, if the task is over allocated memory, then without + // freeing the memory pages, we might not be able to get memory for the pointer array. + synchronized (CometShuffleExternalSorter.this) { + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + } + + return null; + }; + + spillingSorters.add(spillingSorter); + asyncSpillTasks.add(threadPool.submit(task)); + + while (asyncSpillTasks.size() == threadNum) { + for (Future spillingTask : asyncSpillTasks) { + if (spillingTask.isDone()) { + asyncSpillTasks.remove(spillingTask); + break; + } + } + } + + activeSpillSorter = new SpillSorter(); + } else { + activeSpillSorter.writeSortedFileNative(false); + final long spillSize = activeSpillSorter.freeMemory(); + activeSpillSorter.reset(); + + // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding + // the + // records. Otherwise, if the task is over allocated memory, then without freeing the memory + // pages, we might not be able to get memory for the pointer array. + synchronized (CometShuffleExternalSorter.this) { + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + } + } + + spilling = false; + } + + private long getMemoryUsage() { + long totalPageSize = 0; + for (SpillSorter sorter : spillingSorters) { + totalPageSize += sorter.getMemoryUsage(); + } + return totalPageSize; + } + + private void updatePeakMemoryUsed() { + long mem = getMemoryUsage(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** Return the peak memory used so far, in bytes. */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + + private long freeMemory() { + long memoryFreed = 0; + if (isAsync) { + for (SpillSorter sorter : spillingSorters) { + memoryFreed += sorter.freeMemory(); + sorter.freeArray(); + } + } + memoryFreed += activeSpillSorter.freeMemory(); + activeSpillSorter.freeArray(); + + return memoryFreed; + } + + /** Force all memory and spill files to be deleted; called by shuffle error-handling code. */ + public void cleanupResources() { + freeMemory(); + + for (SpillInfo spill : spills) { + if (spill.file.exists() && !spill.file.delete()) { + logger.error("Unable to delete spill file {}", spill.file.getPath()); + } + } + } + + /** + * Checks whether there is enough space to insert an additional record in to the sort pointer + * array and grows the array if additional space is required. If the required space cannot be + * obtained, then the in-memory data will be spilled to disk. + */ + private void growPointerArrayIfNecessary() throws IOException { + assert (activeSpillSorter != null); + if (!activeSpillSorter.hasSpaceForAnotherRecord()) { + long used = activeSpillSorter.getMemoryUsage(); + LongArray array; + try { + // could trigger spilling + array = allocator.allocateArray(used / 8 * 2); + } catch (TooLargePageException e) { + // The pointer array is too big to fix in a single page, spill. + spill(); + return; + } catch (SparkOutOfMemoryError e) { + // Cannot allocate enough memory, spill and reset pointer array. + try { + spill(); + } catch (SparkOutOfMemoryError e2) { + // Cannot allocate memory even after spilling, throw the error. + if (!activeSpillSorter.hasSpaceForAnotherRecord()) { + logger.error("Unable to grow the pointer array"); + throw e2; + } + } + return; + } + // check if spilling is triggered or not + if (activeSpillSorter.hasSpaceForAnotherRecord()) { + allocator.freeArray(array); + } else { + activeSpillSorter.expandPointerArray(array); + } + } + } + + /** + * Writes a record to the shuffle sorter. This copies the record data into this external sorter's + * managed memory, which may trigger spilling if the copy would exceed the memory limit. It + * inserts a pointer for the record and record's partition id into the in-memory sorter. + */ + public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId) + throws IOException { + + assert (activeSpillSorter != null); + int threshold = numElementsForSpillThreshold; + if (activeSpillSorter.numRecords() >= threshold) { + logger.info( + "Spilling data because number of spilledRecords crossed the threshold " + threshold); + spill(); + } + + growPointerArrayIfNecessary(); + + // Need 4 or 8 bytes to store the record length. + final int required = length + uaoSize; + // Acquire enough memory to store the record. + // If we cannot acquire enough memory, we will spill current writers. + if (!activeSpillSorter.acquireNewPageIfNecessary(required)) { + // Spilling is happened, initiate new memory page for new writer. + activeSpillSorter.initialCurrentPage(required); + } + + activeSpillSorter.insertRecord(recordBase, recordOffset, length, partitionId); + } + + /** + * Close the sorter, causing any buffered data to be sorted and written out to disk. + * + * @return metadata for the spill files written by this sorter. If no records were ever inserted + * into this sorter, then this will return an empty array. + */ + public SpillInfo[] closeAndGetSpills() throws IOException { + if (activeSpillSorter != null) { + // Do not count the final file towards the spill count. + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = spilledFileInfo._2(); + final TempShuffleBlockId blockId = spilledFileInfo._1(); + final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId); + + // Waits for all async tasks to finish. + if (isAsync) { + for (Future task : asyncSpillTasks) { + try { + task.get(); + } catch (Exception e) { + throw new IOException(e); + } + } + + asyncSpillTasks.clear(); + } + + activeSpillSorter.setSpillInfo(spillInfo); + activeSpillSorter.writeSortedFileNative(true); + + freeMemory(); + } + + return spills.toArray(new SpillInfo[spills.size()]); + } + + class SpillSorter extends SpillWriter { + private boolean freed = false; + + private SpillInfo spillInfo; + + // These variables are reset after spilling: + @Nullable private ShuffleInMemorySorter inMemSorter; + + // This external sorter can call native code to sort partition ids and record pointers of rows. + // In order to do that, we need pass the address of the internal array in the sorter to native. + // But we cannot access it as it is private member in the Spark sorter. Instead, we allocate + // the array and assign the pointer array in the sorter. + private LongArray sorterArray; + + SpillSorter() { + this.spillInfo = null; + + this.allocator = CometShuffleExternalSorter.this.allocator; + + // Allocate array for in-memory sorter. + // As we cannot access the address of the internal array in the sorter, so we need to + // allocate the array manually and expand the pointer array in the sorter. + // We don't want in-memory sorter to allocate memory but the initial size cannot be zero. + this.inMemSorter = new ShuffleInMemorySorter(allocator, 1, true); + sorterArray = allocator.allocateArray(initialSize); + this.inMemSorter.expandPointerArray(sorterArray); + + this.allocatedPages = new LinkedList<>(); + + this.nativeLib = new Native(); + this.dataTypes = serializeSchema(schema); + } + + /** Frees allocated memory pages of this writer */ + @Override + public long freeMemory() { + // We need to synchronize here because we may get the memory usage by calling + // this method in the task thread. + synchronized (this) { + return super.freeMemory(); + } + } + + @Override + public long getMemoryUsage() { + // We need to synchronize here because we may free the memory pages in another thread, + // i.e. when spilling, but this method may be called in the task thread. + synchronized (this) { + long totalPageSize = super.getMemoryUsage(); + + if (freed) { + return totalPageSize; + } else { + return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize; + } + } + } + + @Override + protected void spill(int required) throws IOException { + CometShuffleExternalSorter.this.spill(); + } + + /** Free the pointer array held by this sorter. */ + public void freeArray() { + synchronized (this) { + inMemSorter.free(); + freed = true; + } + } + + /** + * Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the + * records. + */ + public void reset() { + // We allocate pointer array outside the sorter. + // So we can get array address which can be used by native code. + inMemSorter.reset(); + sorterArray = allocator.allocateArray(initialSize); + inMemSorter.expandPointerArray(sorterArray); + } + + void setSpillInfo(SpillInfo spillInfo) { + this.spillInfo = spillInfo; + } + + public int numRecords() { + return this.inMemSorter.numRecords(); + } + + public void writeSortedFileNative(boolean isLastFile) throws IOException { + // This call performs the actual sort. + long arrayAddr = this.sorterArray.getBaseOffset(); + int pos = inMemSorter.numRecords(); + nativeLib.sortRowPartitionsNative(arrayAddr, pos); + ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords = + new ShuffleInMemorySorter.ShuffleSorterIterator(pos, this.sorterArray, 0); + + // If there are no sorted records, so we don't need to create an empty spill file. + if (!sortedRecords.hasNext()) { + return; + } + + final ShuffleWriteMetricsReporter writeMetricsToUse; + + if (isLastFile) { + // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. + writeMetricsToUse = writeMetrics; + } else { + // We're spilling, so bytes written should be counted towards spill rather than write. + // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count + // them towards shuffle bytes written. + writeMetricsToUse = new ShuffleWriteMetrics(); + } + + int currentPartition = -1; + + final RowPartition rowPartition = new RowPartition(initialSize); + + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final int partition = sortedRecords.packedRecordPointer.getPartitionId(); + assert (partition >= currentPartition); + if (partition != currentPartition) { + // Switch to the new partition + if (currentPartition != -1) { + + if (partitionChecksums.length > 0) { + // If checksum is enabled, we need to update the checksum for the current partition. + setChecksum(partitionChecksums[currentPartition]); + setChecksumAlgo(checksumAlgorithm); + } + + long written = + doSpilling( + dataTypes, + spillInfo.file, + rowPartition, + writeMetricsToUse, + preferDictionaryRatio); + spillInfo.partitionLengths[currentPartition] = written; + + // Store the checksum for the current partition. + partitionChecksums[currentPartition] = getChecksum(); + } + currentPartition = partition; + } + + final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); + final long recordOffsetInPage = allocator.getOffsetInPage(recordPointer); + // Note that we need to skip over record key (partition id) + // Note that we already use off-heap memory for serialized rows, so recordPage is always + // null. + int recordSizeInBytes = UnsafeAlignedOffset.getSize(null, recordOffsetInPage) - 4; + long recordReadPosition = recordOffsetInPage + uaoSize + 4; // skip over record length too + rowPartition.addRow(recordReadPosition, recordSizeInBytes); + } + + if (currentPartition != -1) { + long written = + doSpilling( + dataTypes, spillInfo.file, rowPartition, writeMetricsToUse, preferDictionaryRatio); + spillInfo.partitionLengths[currentPartition] = written; + + synchronized (spills) { + spills.add(spillInfo); + } + } + + if (!isLastFile) { // i.e. this is a spill file + // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when + // records + // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter + // relies on its `recordWritten()` method being called in order to trigger periodic updates + // to + // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that + // counter at a higher-level, then the in-progress metrics for records written and bytes + // written would get out of sync. + // + // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter; + // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those + // metrics to the true write metrics here. The reason for performing this copying is so that + // we can avoid reporting spilled bytes as shuffle write bytes. + // + // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`. + // Consistent with ExternalSorter, we do not count this IO towards shuffle write time. + // SPARK-3577 tracks the spill time separately. + + // This is guaranteed to be a ShuffleWriteMetrics based on the if check in the beginning + // of this method. + synchronized (writeMetrics) { + writeMetrics.incRecordsWritten( + ((ShuffleWriteMetrics) writeMetricsToUse).recordsWritten()); + taskContext + .taskMetrics() + .incDiskBytesSpilled(((ShuffleWriteMetrics) writeMetricsToUse).bytesWritten()); + } + } + } + + public boolean hasSpaceForAnotherRecord() { + return inMemSorter.hasSpaceForAnotherRecord(); + } + + public void expandPointerArray(LongArray newArray) { + inMemSorter.expandPointerArray(newArray); + this.sorterArray = newArray; + } + + public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId) { + final Object base = currentPage.getBaseObject(); + final long recordAddress = allocator.encodePageNumberAndOffset(currentPage, pageCursor); + UnsafeAlignedOffset.putSize(base, pageCursor, length); + pageCursor += uaoSize; + Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); + pageCursor += length; + inMemSorter.insertRecord(recordAddress, partitionId); + } + } +} diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java new file mode 100644 index 000000000..5c17a643a --- /dev/null +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.channels.FileChannel; +import java.util.Optional; +import java.util.concurrent.ExecutorService; + +import scala.*; +import scala.collection.Iterator; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.Partitioner; +import org.apache.spark.ShuffleDependency; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.internal.config.package$; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.scheduler.MapStatus$; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.WritableByteChannelWrapper; +import org.apache.spark.shuffle.comet.CometShuffleChecksumSupport; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.FileSegment; +import org.apache.spark.storage.TempShuffleBlockId; +import org.apache.spark.util.Utils; + +import com.google.common.io.Closeables; + +import org.apache.comet.CometConf$; + +/** + * This is based on Spark `BypassMergeSortShuffleWriter`. Instead of `DiskBlockObjectWriter`, this + * writes Spark Internal Rows through `DiskBlockArrowIPCWriter` as Arrow IPC bytes. Note that Spark + * `DiskBlockObjectWriter` is general writer for any objects, but `DiskBlockArrowIPCWriter` is + * specialized for Spark Internal Rows and SQL workloads. + */ +final class CometBypassMergeSortShuffleWriter extends ShuffleWriter + implements CometShuffleChecksumSupport { + + private static final Logger logger = + LoggerFactory.getLogger(CometBypassMergeSortShuffleWriter.class); + private final int fileBufferSize; + private final boolean transferToEnabled; + private final int numPartitions; + private final BlockManager blockManager; + private final TaskMemoryManager memoryManager; + private final TaskContext taskContext; + private final SerializerInstance serializer; + + private final Partitioner partitioner; + private final ShuffleWriteMetricsReporter writeMetrics; + private final int shuffleId; + private final long mapId; + private final ShuffleExecutorComponents shuffleExecutorComponents; + private final StructType schema; + + /** Array of file writers, one for each partition */ + private CometDiskBlockWriter[] partitionWriters; + + private FileSegment[] partitionWriterSegments; + private MapStatus mapStatus; + private long[] partitionLengths; + + /** Checksum calculator for each partition. Empty when shuffle checksum disabled. */ + private final long[] partitionChecksums; + + private final boolean isAsync; + + private final int asyncThreadNum; + + /** Thread pool shared across all partition writers, for async write batch */ + private final ExecutorService threadPool; + + /** + * Are we in the process of stopping? Because map tasks can call stop() with success = true and + * then call stop() with success = false if they get an exception, we want to make sure we don't + * try deleting files, etc twice. + */ + private boolean stopping = false; + + private final SparkConf conf; + + CometBypassMergeSortShuffleWriter( + BlockManager blockManager, + TaskMemoryManager memoryManager, + TaskContext taskContext, + CometBypassMergeSortShuffleHandle handle, + long mapId, + SparkConf conf, + ShuffleWriteMetricsReporter writeMetrics, + ShuffleExecutorComponents shuffleExecutorComponents) { + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; + this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); + this.conf = conf; + this.blockManager = blockManager; + this.memoryManager = memoryManager; + this.taskContext = taskContext; + final ShuffleDependency dep = handle.dependency(); + this.mapId = mapId; + this.serializer = dep.serializer().newInstance(); + this.shuffleId = dep.shuffleId(); + this.partitioner = dep.partitioner(); + this.numPartitions = partitioner.numPartitions(); + this.writeMetrics = writeMetrics; + this.shuffleExecutorComponents = shuffleExecutorComponents; + this.schema = ((CometShuffleDependency) dep).schema().get(); + this.partitionChecksums = createPartitionChecksums(numPartitions, conf); + + this.isAsync = (boolean) CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED().get(); + this.asyncThreadNum = (int) CometConf$.MODULE$.COMET_EXEC_SHUFFLE_ASYNC_THREAD_NUM().get(); + + if (isAsync) { + logger.info("Async shuffle writer enabled"); + this.threadPool = ShuffleThreadPool.getThreadPool(); + } else { + logger.info("Async shuffle writer disabled"); + this.threadPool = null; + } + } + + @Override + public void write(Iterator> records) throws IOException { + assert (partitionWriters == null); + ShuffleMapOutputWriter mapOutputWriter = + shuffleExecutorComponents.createMapOutputWriter(shuffleId, mapId, numPartitions); + try { + if (!records.hasNext()) { + partitionLengths = + mapOutputWriter + .commitAllPartitions(ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE) + .getPartitionLengths(); + mapStatus = + MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, mapId); + return; + } + final long openStartTime = System.nanoTime(); + partitionWriters = new CometDiskBlockWriter[numPartitions]; + partitionWriterSegments = new FileSegment[numPartitions]; + + final String checksumAlgorithm = getChecksumAlgorithm(conf); + + // Allocate the disk writers, and open the files that we'll be writing to + for (int i = 0; i < numPartitions; i++) { + final Tuple2 tempShuffleBlockIdPlusFile = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = tempShuffleBlockIdPlusFile._2(); + CometDiskBlockWriter writer = + new CometDiskBlockWriter( + file, + memoryManager, + taskContext, + serializer, + schema, + writeMetrics, + conf, + isAsync, + asyncThreadNum, + threadPool); + if (partitionChecksums.length > 0) { + writer.setChecksum(partitionChecksums[i]); + writer.setChecksumAlgo(checksumAlgorithm); + } + partitionWriters[i] = writer; + } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. + writeMetrics.incWriteTime(System.nanoTime() - openStartTime); + + long outputRows = 0; + + while (records.hasNext()) { + outputRows += 1; + + final Product2 record = records.next(); + final K key = record._1(); + // Safety: `CometBypassMergeSortShuffleWriter` is only used when dealing with Comet shuffle + // dependencies, which always produce `ColumnarBatch`es. + int partition_id = partitioner.getPartition(key); + partitionWriters[partitioner.getPartition(key)].insertRow( + (UnsafeRow) record._2(), partition_id); + } + + long spillRecords = 0; + + for (int i = 0; i < numPartitions; i++) { + CometDiskBlockWriter writer = partitionWriters[i]; + partitionWriterSegments[i] = writer.close(); + + spillRecords += writer.getOutputRecords(); + } + + if (outputRows != spillRecords) { + throw new RuntimeException( + "outputRows(" + + outputRows + + ") != spillRecords(" + + spillRecords + + "). Please file a bug report."); + } + + // TODO: We probably can move checksum generation here when concatenating partition files + partitionLengths = writePartitionedData(mapOutputWriter); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, mapId); + } catch (Exception e) { + try { + mapOutputWriter.abort(e); + } catch (Exception e2) { + logger.error("Failed to abort the writer after failing to write map output.", e2); + e.addSuppressed(e2); + } + throw e; + } + } + + @Override + public long[] getPartitionLengths() { + return partitionLengths; + } + + /** + * Concatenate all of the per-partition files into a single combined file. + * + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). + */ + private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException { + // Track location of the partition starts in the output file + if (partitionWriters != null) { + final long writeStartTime = System.nanoTime(); + final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); + + // Gets computed checksums for each partition + for (int i = 0; i < partitionChecksums.length; i++) { + partitionChecksums[i] = partitionWriters[i].getChecksum(); + } + + try { + for (int i = 0; i < numPartitions; i++) { + final File file = partitionWriterSegments[i].file(); + ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i); + + if (file.exists()) { + if (transferToEnabled && !encryptionEnabled) { + // Using WritableByteChannelWrapper to make resource closing consistent between + // this implementation and UnsafeShuffleWriter. + Optional maybeOutputChannel = writer.openChannelWrapper(); + if (maybeOutputChannel.isPresent()) { + writePartitionedDataWithChannel(file, maybeOutputChannel.get()); + } else { + writePartitionedDataWithStream(file, writer); + } + } else { + writePartitionedDataWithStream(file, writer); + } + if (!file.delete()) { + logger.error("Unable to delete file for partition {}", i); + } + } + } + } finally { + writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); + } + partitionWriters = null; + } + + return mapOutputWriter.commitAllPartitions(partitionChecksums).getPartitionLengths(); + } + + private void writePartitionedDataWithChannel(File file, WritableByteChannelWrapper outputChannel) + throws IOException { + boolean copyThrewException = true; + try { + FileInputStream in = new FileInputStream(file); + try (FileChannel inputChannel = in.getChannel()) { + Utils.copyFileStreamNIO(inputChannel, outputChannel.channel(), 0L, inputChannel.size()); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + } finally { + Closeables.close(outputChannel, copyThrewException); + } + } + + private void writePartitionedDataWithStream(File file, ShufflePartitionWriter writer) + throws IOException { + boolean copyThrewException = true; + FileInputStream in = new FileInputStream(file); + OutputStream outputStream; + try { + outputStream = blockManager.serializerManager().wrapForEncryption(writer.openStream()); + + try { + Utils.copyStream(in, outputStream, false, false); + copyThrewException = false; + } finally { + Closeables.close(outputStream, copyThrewException); + } + } finally { + Closeables.close(in, copyThrewException); + } + } + + @Override + public Option stop(boolean success) { + if (stopping) { + return None$.empty(); + } else { + stopping = true; + + if (success) { + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } + return Option.apply(mapStatus); + } else { + // The map task failed, so delete our output data. + if (partitionWriters != null) { + try { + for (CometDiskBlockWriter writer : partitionWriters) { + writer.freeMemory(); + + File file = writer.getFile(); + if (!file.delete()) { + logger.error("Error while deleting file {}", file.getAbsolutePath()); + } + } + } finally { + partitionWriters = null; + } + } + return None$.empty(); + } + } + } +} diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java new file mode 100644 index 000000000..d1593f725 --- /dev/null +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java @@ -0,0 +1,451 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle; + +import java.io.*; +import java.util.Collections; +import java.util.Comparator; +import java.util.LinkedList; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; + +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.internal.config.package$; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.serializer.SerializationStream; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.shuffle.comet.CometShuffleMemoryAllocator; +import org.apache.spark.shuffle.sort.RowPartition; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.storage.FileSegment; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; + +import com.google.common.annotations.VisibleForTesting; + +import org.apache.comet.CometConf$; +import org.apache.comet.Native; + +/** + * This class has similar role of Spark `DiskBlockObjectWriter` class which is used to write shuffle + * data to disk. For Comet, this is specialized to shuffle unsafe rows into disk in Arrow IPC format + * using native code. Spark `DiskBlockObjectWriter` is not a `MemoryConsumer` as it can simply + * stream the data to disk. However, for Comet, we need to buffer rows in memory page and then write + * them to disk as batches in Arrow IPC format. So, we need to extend `MemoryConsumer` to be able to + * spill the buffered rows to disk when memory pressure is high. + * + *

Similar to `CometShuffleExternalSorter`, this class also provides asynchronous spill + * mechanism. But different from `CometShuffleExternalSorter`, as a writer class for Spark + * hash-based shuffle, it writes all the rows for a partition into a single file, instead of each + * file for each spill. + */ +public final class CometDiskBlockWriter { + private static final Logger logger = LoggerFactory.getLogger(CometDiskBlockWriter.class); + private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); + + /** List of all `NativeDiskBlockArrowIPCWriter`s of same shuffle task. */ + private static final LinkedList currentWriters = new LinkedList<>(); + + /** Queue of pending asynchronous spill tasks. */ + private ConcurrentLinkedQueue> asyncSpillTasks = new ConcurrentLinkedQueue<>(); + + /** List of `ArrowIPCWriter`s which are spilling. */ + private final LinkedList spillingWriters = new LinkedList<>(); + + private final TaskContext taskContext; + + @VisibleForTesting static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; + + // Copied from Spark `org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES` + static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; + + /** The Comet allocator used to allocate pages. */ + private final CometShuffleMemoryAllocator allocator; + + /** The serializer used to write rows to memory page. */ + private final SerializerInstance serializer; + + /** The native library used to write rows to disk. */ + private final Native nativeLib; + + private final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + private final StructType schema; + private final ShuffleWriteMetricsReporter writeMetrics; + private final File file; + private long totalWritten = 0L; + private boolean initialized = false; + private final int initialBufferSize; + private final boolean isAsync; + private final int asyncThreadNum; + private final ExecutorService threadPool; + private final int numElementsForSpillThreshold; + + private final double preferDictionaryRatio; + + /** The current active writer. All incoming rows will be inserted into it. */ + private ArrowIPCWriter activeWriter; + + /** A flag indicating whether we are in the process of spilling. */ + private boolean spilling = false; + + /** The buffer used to store serialized row. */ + private ExposedByteArrayOutputStream serBuffer; + + private SerializationStream serOutputStream; + + private long outputRecords = 0; + + private long insertRecords = 0; + + CometDiskBlockWriter( + File file, + TaskMemoryManager taskMemoryManager, + TaskContext taskContext, + SerializerInstance serializer, + StructType schema, + ShuffleWriteMetricsReporter writeMetrics, + SparkConf conf, + boolean isAsync, + int asyncThreadNum, + ExecutorService threadPool) { + this.allocator = + CometShuffleMemoryAllocator.getInstance( + conf, + taskMemoryManager, + Math.min(MAXIMUM_PAGE_SIZE_BYTES, taskMemoryManager.pageSizeBytes())); + this.nativeLib = new Native(); + + this.taskContext = taskContext; + this.serializer = serializer; + this.schema = schema; + this.writeMetrics = writeMetrics; + this.file = file; + this.isAsync = isAsync; + this.asyncThreadNum = asyncThreadNum; + this.threadPool = threadPool; + + this.initialBufferSize = + (int) (long) conf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()); + + this.numElementsForSpillThreshold = + (int) CometConf$.MODULE$.COMET_EXEC_SHUFFLE_SPILL_THRESHOLD().get(); + + this.preferDictionaryRatio = + (double) CometConf$.MODULE$.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO().get(); + + this.activeWriter = new ArrowIPCWriter(); + + synchronized (currentWriters) { + currentWriters.add(this); + } + } + + public void setChecksumAlgo(String checksumAlgo) { + this.activeWriter.setChecksumAlgo(checksumAlgo); + } + + public void setChecksum(long checksum) { + this.activeWriter.setChecksum(checksum); + } + + public long getChecksum() { + return this.activeWriter.getChecksum(); + } + + private void doSpill(boolean forceSync) throws IOException { + // We only allow spilling request from `NativeDiskBlockArrowIPCWriter`. + if (spilling || activeWriter.numRecords() == 0) { + return; + } + + // Set this into spilling state first, so it cannot recursively trigger another spill on itself. + spilling = true; + + if (isAsync && !forceSync) { + // Although we can continue to submit spill tasks to thread pool, buffering more rows in + // memory page will increase memory usage. So, we need to wait for at least one spilling + // task to finish. + while (asyncSpillTasks.size() == asyncThreadNum) { + for (Future task : asyncSpillTasks) { + if (task.isDone()) { + asyncSpillTasks.remove(task); + break; + } + } + } + + final ArrowIPCWriter spillingWriter = activeWriter; + activeWriter = new ArrowIPCWriter(); + + spillingWriters.add(spillingWriter); + + asyncSpillTasks.add( + threadPool.submit( + new Runnable() { + @Override + public void run() { + try { + long written = spillingWriter.doSpilling(false); + totalWritten += written; + } catch (IOException e) { + throw new RuntimeException(e); + } finally { + spillingWriter.freeMemory(); + spillingWriters.remove(spillingWriter); + } + } + }, + null)); + + } else { + // Spill in a synchronous way. + // This spill could be triggered by other thread (i.e., other `CometDiskBlockWriter`), + // so we need to synchronize it. + synchronized (CometDiskBlockWriter.this) { + totalWritten += activeWriter.doSpilling(false); + activeWriter.freeMemory(); + } + } + + spilling = false; + } + + public long getOutputRecords() { + return outputRecords; + } + + /** Serializes input row and inserts into current allocated page. */ + public void insertRow(UnsafeRow row, int partitionId) throws IOException { + insertRecords++; + + if (!initialized) { + serBuffer = new ExposedByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE); + serOutputStream = serializer.serializeStream(serBuffer); + + initialized = true; + } + + serBuffer.reset(); + serOutputStream.writeKey(partitionId, OBJECT_CLASS_TAG); + serOutputStream.writeValue(row, OBJECT_CLASS_TAG); + serOutputStream.flush(); + + final int serializedRecordSize = serBuffer.size(); + assert (serializedRecordSize > 0); + + // While proceeding with possible spilling and inserting the record, we need to synchronize + // it, because other threads may be spilling this writer at the same time. + synchronized (CometDiskBlockWriter.this) { + if (activeWriter.numRecords() >= numElementsForSpillThreshold) { + logger.info( + "Spilling data because number of spilledRecords crossed the threshold " + + numElementsForSpillThreshold); + // Spill the current writer + doSpill(false); + if (activeWriter.numRecords() != 0) { + throw new RuntimeException( + "activeWriter.numRecords()(" + activeWriter.numRecords() + ") != 0"); + } + } + + // Need 4 or 8 bytes to store the record length. + final int required = serializedRecordSize + uaoSize; + // Acquire enough memory to store the record. + // If we cannot acquire enough memory, we will spill current writers. + if (!activeWriter.acquireNewPageIfNecessary(required)) { + // Spilling is happened, initiate new memory page for new writer. + activeWriter.initialCurrentPage(required); + } + activeWriter.insertRecord( + serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize); + } + } + + FileSegment close() throws IOException { + if (isAsync) { + for (Future task : asyncSpillTasks) { + try { + task.get(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + + totalWritten += activeWriter.doSpilling(true); + + if (outputRecords != insertRecords) { + throw new RuntimeException( + "outputRecords(" + + outputRecords + + ") != insertRecords(" + + insertRecords + + "). Please file a bug report."); + } + + serBuffer = null; + serOutputStream = null; + + activeWriter.freeMemory(); + + synchronized (currentWriters) { + currentWriters.remove(this); + } + + return new FileSegment(file, 0, totalWritten); + } + + File getFile() { + return file; + } + + /** Returns the memory usage of active writer. */ + long getActiveMemoryUsage() { + return activeWriter.getMemoryUsage(); + } + + void freeMemory() { + for (ArrowIPCWriter writer : spillingWriters) { + writer.freeMemory(); + } + activeWriter.freeMemory(); + } + + class ArrowIPCWriter extends SpillWriter { + /** + * The list of addresses and sizes of rows buffered in memory page which wait for + * spilling/shuffle. + */ + private final RowPartition rowPartition; + + ArrowIPCWriter() { + rowPartition = new RowPartition(initialBufferSize); + + this.allocatedPages = new LinkedList<>(); + this.allocator = CometDiskBlockWriter.this.allocator; + + this.nativeLib = CometDiskBlockWriter.this.nativeLib; + this.dataTypes = serializeSchema(schema); + } + + /** Inserts a record into current allocated page. */ + void insertRecord(Object recordBase, long recordOffset, int length) { + // This `ArrowIPCWriter` could be spilled by other threads, so we need to synchronize it. + final Object base = currentPage.getBaseObject(); + + // Add row addresses + final long recordAddress = allocator.encodePageNumberAndOffset(currentPage, pageCursor); + rowPartition.addRow(allocator.getOffsetInPage(recordAddress) + uaoSize + 4, length - 4); + + // Write the record (row) size + UnsafeAlignedOffset.putSize(base, pageCursor, length); + pageCursor += uaoSize; + // Copy the record (row) data from serialized buffer to page + Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); + pageCursor += length; + } + + int numRecords() { + return rowPartition.getNumRows(); + } + + /** Spills the current in-memory records of this `ArrowIPCWriter` to disk. */ + long doSpilling(boolean isLast) throws IOException { + final ShuffleWriteMetricsReporter writeMetricsToUse; + + if (isLast) { + // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. + writeMetricsToUse = writeMetrics; + } else { + // We're spilling, so bytes written should be counted towards spill rather than write. + // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count + // them towards shuffle bytes written. + writeMetricsToUse = new ShuffleWriteMetrics(); + } + + final long written; + + // All threads are writing to the same file, so we need to synchronize it. + synchronized (file) { + outputRecords += rowPartition.getNumRows(); + written = + doSpilling(dataTypes, file, rowPartition, writeMetricsToUse, preferDictionaryRatio); + } + + // Update metrics + // Other threads may be updating the metrics at the same time, so we need to synchronize it. + synchronized (writeMetrics) { + if (!isLast) { + writeMetrics.incRecordsWritten( + ((ShuffleWriteMetrics) writeMetricsToUse).recordsWritten()); + taskContext + .taskMetrics() + .incDiskBytesSpilled(((ShuffleWriteMetrics) writeMetricsToUse).bytesWritten()); + } + } + + return written; + } + + /** + * Spills the current in-memory records of all `NativeDiskBlockArrowIPCWriter`s until required + * memory is acquired. + */ + @Override + protected void spill(int required) throws IOException { + // Cannot allocate enough memory, spill and try again + synchronized (currentWriters) { + // Spill from the largest writer first to maximize the amount of memory we can + // acquire + Collections.sort( + currentWriters, + new Comparator() { + @Override + public int compare(CometDiskBlockWriter lhs, CometDiskBlockWriter rhs) { + long lhsMemoryUsage = lhs.getActiveMemoryUsage(); + long rhsMemoryUsage = rhs.getActiveMemoryUsage(); + return Long.compare(rhsMemoryUsage, lhsMemoryUsage); + } + }); + + for (CometDiskBlockWriter writer : currentWriters) { + // Force to spill the writer in a synchronous way, otherwise, we may not be able to + // acquire enough memory. + writer.doSpill(true); + + if (allocator.getAvailableMemory() >= required) { + break; + } + } + } + } + } +} diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java new file mode 100644 index 000000000..ae38e4ffc --- /dev/null +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java @@ -0,0 +1,573 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle; + +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.channels.Channels; +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; +import java.util.Arrays; +import java.util.Iterator; +import java.util.Optional; +import javax.annotation.Nullable; + +import scala.Option; +import scala.Product2; +import scala.collection.JavaConverters; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.Partitioner; +import org.apache.spark.ShuffleDependency; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.annotation.Private; +import org.apache.spark.internal.config.package$; +import org.apache.spark.io.NioBufferedFileInputStream; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; +import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.scheduler.MapStatus$; +import org.apache.spark.serializer.SerializationStream; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.BaseShuffleHandle; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.WritableByteChannelWrapper; +import org.apache.spark.shuffle.sort.CometShuffleExternalSorter; +import org.apache.spark.shuffle.sort.SortShuffleManager; +import org.apache.spark.shuffle.sort.UnsafeShuffleWriter; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.TimeTrackingOutputStream; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.Utils; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.ByteStreams; +import com.google.common.io.Closeables; + +/** + * This is based on Spark {@link UnsafeShuffleWriter}, as a writer to write shuffling rows into + * Arrow format after sorting rows based on the partition ID. + * + *

While writing rows through this writer, it writes rows into {@link CometShuffleExternalSorter} + * which buffers rows in memory (`ShuffleInMemorySorter`). When the memory buffer is full, it will + * sort rows based on the partition ID and write them into a spill file. In Spark, sorting is + * performed by `ShuffleInMemorySorter`. In Comet, if off-heap memory is enabled, and radix sort is + * enabled, Comet will sort rows through native code. Sorting is based on a long array containing + * compacted partition IDs and row pointers. The row pointers are the addresses of the rows in the + * off-heap memory. After sorting, Comet will write the sorted rows into a spill file through native + * code (off-heap memory must be enabled). + * + *

In native code, Comet converts UnsafeRows to Arrow arrays and writes arrays into the spill + * file with Arrow IPC format. + */ +@Private +public class CometUnsafeShuffleWriter extends ShuffleWriter { + + @VisibleForTesting static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; + private static final Logger logger = LoggerFactory.getLogger(CometUnsafeShuffleWriter.class); + private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); + private final BlockManager blockManager; + private final TaskMemoryManager memoryManager; + private final SerializerInstance serializer; + private final Partitioner partitioner; + private final ShuffleWriteMetricsReporter writeMetrics; + private final ShuffleExecutorComponents shuffleExecutorComponents; + private final int shuffleId; + private final long mapId; + private final TaskContext taskContext; + private final SparkConf sparkConf; + private final boolean transferToEnabled; + private final int initialSortBufferSize; + private final int inputBufferSizeInBytes; + private final StructType schema; + + @Nullable private MapStatus mapStatus; + @Nullable private CometShuffleExternalSorter sorter; + + @Nullable private long[] partitionLengths; + + private long peakMemoryUsedBytes = 0; + private ExposedByteArrayOutputStream serBuffer; + private SerializationStream serOutputStream; + /** + * Are we in the process of stopping? Because map tasks can call stop() with success = true and + * then call stop() with success = false if they get an exception, we want to make sure we don't + * try deleting files, etc twice. + */ + private boolean stopping = false; + + public CometUnsafeShuffleWriter( + BlockManager blockManager, + TaskMemoryManager memoryManager, + BaseShuffleHandle handle, + long mapId, + TaskContext taskContext, + SparkConf sparkConf, + ShuffleWriteMetricsReporter writeMetrics, + ShuffleExecutorComponents shuffleExecutorComponents) { + final int numPartitions = handle.dependency().partitioner().numPartitions(); + if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { + throw new IllegalArgumentException( + "CometUnsafeShuffleWriter can only be used for shuffles with at most " + + SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + + " reduce partitions"); + } + this.blockManager = blockManager; + this.memoryManager = memoryManager; + this.mapId = mapId; + final ShuffleDependency dep = handle.dependency(); + this.shuffleId = dep.shuffleId(); + this.serializer = dep.serializer().newInstance(); + this.partitioner = dep.partitioner(); + this.schema = (StructType) ((CometShuffleDependency) dep).schema().get(); + this.writeMetrics = writeMetrics; + this.shuffleExecutorComponents = shuffleExecutorComponents; + this.taskContext = taskContext; + this.sparkConf = sparkConf; + this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); + this.initialSortBufferSize = + (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()); + this.inputBufferSizeInBytes = + (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; + open(); + } + + private static OutputStream openStreamUnchecked(ShufflePartitionWriter writer) { + try { + return writer.openStream(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private void updatePeakMemoryUsed() { + // sorter can be null if this writer is closed + if (sorter != null) { + long mem = sorter.getPeakMemoryUsedBytes(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + } + + /** Return the peak memory used so far, in bytes. */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + + /** This convenience method should only be called in test code. */ + @VisibleForTesting + public void write(Iterator> records) throws IOException { + write(JavaConverters.asScalaIteratorConverter(records).asScala()); + } + + @Override + public void write(scala.collection.Iterator> records) throws IOException { + // Keep track of success so we know if we encountered an exception + // We do this rather than a standard try/catch/re-throw to handle + // generic throwables. + boolean success = false; + try { + while (records.hasNext()) { + insertRecordIntoSorter(records.next()); + } + closeAndWriteOutput(); + success = true; + } finally { + if (sorter != null) { + try { + sorter.cleanupResources(); + } catch (Exception e) { + // Only throw this error if we won't be masking another + // error. + if (success) { + throw e; + } else { + logger.error( + "In addition to a failure during writing, we failed during " + "cleanup.", e); + } + } + } + } + } + + private void open() { + assert (sorter == null); + sorter = + new CometShuffleExternalSorter( + memoryManager, + blockManager, + taskContext, + initialSortBufferSize, + partitioner.numPartitions(), + sparkConf, + writeMetrics, + schema); + serBuffer = new ExposedByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE); + serOutputStream = serializer.serializeStream(serBuffer); + } + + @VisibleForTesting + void closeAndWriteOutput() throws IOException { + assert (sorter != null); + updatePeakMemoryUsed(); + serBuffer = null; + serOutputStream = null; + final SpillInfo[] spills = sorter.closeAndGetSpills(); + try { + partitionLengths = mergeSpills(spills); + } finally { + sorter = null; + for (SpillInfo spill : spills) { + if (spill.file.exists() && !spill.file.delete()) { + logger.error("Error while deleting spill file {}", spill.file.getPath()); + } + } + } + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, mapId); + } + + @VisibleForTesting + /** Serializes records and inserts into external sorter */ + void insertRecordIntoSorter(Product2 record) throws IOException { + assert (sorter != null); + final K key = record._1(); + final int partitionId = partitioner.getPartition(key); + serBuffer.reset(); + serOutputStream.writeKey(key, OBJECT_CLASS_TAG); + serOutputStream.writeValue((UnsafeRow) record._2(), OBJECT_CLASS_TAG); + serOutputStream.flush(); + + final int serializedRecordSize = serBuffer.size(); + assert (serializedRecordSize > 0); + + sorter.insertRecord( + serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + } + + @VisibleForTesting + void forceSorterToSpill() throws IOException { + assert (sorter != null); + sorter.spill(); + } + + /** + * Merge zero or more spill files together, choosing the fastest merging strategy based on the + * number of spills and the IO compression codec. + * + * @return the partition lengths in the merged file. + */ + private long[] mergeSpills(SpillInfo[] spills) throws IOException { + final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); + + long[] partitionLengths; + if (spills.length == 0) { + final ShuffleMapOutputWriter mapWriter = + shuffleExecutorComponents.createMapOutputWriter( + shuffleId, mapId, partitioner.numPartitions()); + return mapWriter + .commitAllPartitions(ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE) + .getPartitionLengths(); + } else if (spills.length == 1) { + Optional maybeSingleFileWriter = + shuffleExecutorComponents.createSingleFileMapOutputWriter(shuffleId, mapId); + if (maybeSingleFileWriter.isPresent() && !encryptionEnabled) { + // Comet native writer doesn't perform encryption. If encryption is enabled, we should + // perform spill merging which will perform encryption. + + // Here, we don't need to perform any metrics updates because the bytes written to this + // output file would have already been counted as shuffle bytes written. + partitionLengths = spills[0].partitionLengths; + logger.debug( + "Merge shuffle spills for mapId {} with length {}", mapId, partitionLengths.length); + maybeSingleFileWriter + .get() + .transferMapSpillFile(spills[0].file, partitionLengths, sorter.getChecksums()); + } else { + partitionLengths = mergeSpillsUsingStandardWriter(spills); + } + } else { + partitionLengths = mergeSpillsUsingStandardWriter(spills); + } + return partitionLengths; + } + + private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOException { + long[] partitionLengths; + final boolean fastMergeEnabled = + (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FAST_MERGE_ENABLE()); + final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); + final ShuffleMapOutputWriter mapWriter = + shuffleExecutorComponents.createMapOutputWriter( + shuffleId, mapId, partitioner.numPartitions()); + try { + // There are multiple spills to merge, so none of these spill files' lengths were counted + // towards our shuffle write count or shuffle write time. If we use the slow merge path, + // then the final output file's size won't necessarily be equal to the sum of the spill + // files' sizes. To guard against this case, we look at the output file's actual size when + // computing shuffle bytes written. + // + // We allow the individual merge methods to report their own IO times since different merge + // strategies use different IO techniques. We count IO during merge towards the shuffle + // write time, which appears to be consistent with the "not bypassing merge-sort" branch in + // ExternalSorter. + if (fastMergeEnabled) { + // Comet native shuffle writer uses its compression codec instead of Spark's. So different + // Spark where fast spill merge is only supported when compression is disabled or when + // using a compression codec that supports concatenation of compressed streams, we can + // perform a fast spill merge that doesn't need to interpret the spilled bytes. + if (transferToEnabled && !encryptionEnabled) { + logger.debug("Using transferTo-based fast merge"); + mergeSpillsWithTransferTo(spills, mapWriter); + } else { + logger.debug("Using fileStream-based fast merge"); + mergeSpillsWithFileStream(spills, mapWriter); + } + } else { + logger.debug("Using slow merge"); + mergeSpillsWithFileStream(spills, mapWriter); + } + // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has + // in-memory records, we write out the in-memory records to a file but do not count that + // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs + // to be counted as shuffle write, but this will lead to double-counting of the final + // SpillInfo's bytes. + writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); + partitionLengths = mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths(); + } catch (Exception e) { + try { + mapWriter.abort(e); + } catch (Exception e2) { + logger.warn("Failed to abort writing the map output.", e2); + e.addSuppressed(e2); + } + throw e; + } + return partitionLengths; + } + + /** + * Merges spill files using Java FileStreams. This code path is typically slower than the + * NIO-based merge, {@link CometUnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], + * ShuffleMapOutputWriter)}, and it's mostly used in cases where the IO compression codec does not + * support concatenation of compressed data, when encryption is enabled, or when users have + * explicitly disabled use of {@code transferTo} in order to work around kernel bugs. This code + * path might also be faster in cases where individual partition size in a spill is small and + * UnsafeShuffleWriter#mergeSpillsWithTransferTo method performs many small disk ios which is + * inefficient. In those case, Using large buffers for input and output files helps reducing the + * number of disk ios, making the file merging faster. + * + * @param spills the spills to merge. + * @param mapWriter the map output writer to use for output. + * @return the partition lengths in the merged file. + */ + private void mergeSpillsWithFileStream(SpillInfo[] spills, ShuffleMapOutputWriter mapWriter) + throws IOException { + logger.debug("Merge shuffle spills with FileStream for mapId {}", mapId); + final int numPartitions = partitioner.numPartitions(); + final InputStream[] spillInputStreams = new InputStream[spills.length]; + + boolean threwException = true; + try { + for (int i = 0; i < spills.length; i++) { + spillInputStreams[i] = + new NioBufferedFileInputStream(spills[i].file, inputBufferSizeInBytes); + // Only convert the partitionLengths when debug level is enabled. + if (logger.isDebugEnabled()) { + logger.debug( + "Partition lengths for mapId {} in Spill {}: {}", + mapId, + i, + Arrays.toString(spills[i].partitionLengths)); + } + } + for (int partition = 0; partition < numPartitions; partition++) { + boolean copyThrewException = true; + ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); + OutputStream partitionOutput = writer.openStream(); + try { + partitionOutput = new TimeTrackingOutputStream(writeMetrics, partitionOutput); + partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + + if (partitionLengthInSpill > 0) { + InputStream partitionInputStream = null; + boolean copySpillThrewException = true; + try { + partitionInputStream = + new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false); + ByteStreams.copy(partitionInputStream, partitionOutput); + copySpillThrewException = false; + } finally { + Closeables.close(partitionInputStream, copySpillThrewException); + } + } + } + copyThrewException = false; + } finally { + Closeables.close(partitionOutput, copyThrewException); + } + long numBytesWritten = writer.getNumBytesWritten(); + writeMetrics.incBytesWritten(numBytesWritten); + } + threwException = false; + } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. + for (InputStream stream : spillInputStreams) { + Closeables.close(stream, threwException); + } + } + } + + /** + * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes. This is + * only safe when the IO compression codec and serializer support concatenation of serialized + * streams. + * + * @param spills the spills to merge. + * @param mapWriter the map output writer to use for output. + * @return the partition lengths in the merged file. + */ + private void mergeSpillsWithTransferTo(SpillInfo[] spills, ShuffleMapOutputWriter mapWriter) + throws IOException { + logger.debug("Merge shuffle spills with TransferTo for mapId {}", mapId); + final int numPartitions = partitioner.numPartitions(); + final FileChannel[] spillInputChannels = new FileChannel[spills.length]; + final long[] spillInputChannelPositions = new long[spills.length]; + + boolean threwException = true; + try { + for (int i = 0; i < spills.length; i++) { + spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); + // Only convert the partitionLengths when debug level is enabled. + if (logger.isDebugEnabled()) { + logger.debug( + "Partition lengths for mapId {} in Spill {}: {}", + mapId, + i, + Arrays.toString(spills[i].partitionLengths)); + } + } + for (int partition = 0; partition < numPartitions; partition++) { + boolean copyThrewException = true; + ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); + WritableByteChannelWrapper resolvedChannel = + writer + .openChannelWrapper() + .orElseGet(() -> new StreamFallbackChannelWrapper(openStreamUnchecked(writer))); + try { + for (int i = 0; i < spills.length; i++) { + long partitionLengthInSpill = spills[i].partitionLengths[partition]; + final FileChannel spillInputChannel = spillInputChannels[i]; + final long writeStartTime = System.nanoTime(); + Utils.copyFileStreamNIO( + spillInputChannel, + resolvedChannel.channel(), + spillInputChannelPositions[i], + partitionLengthInSpill); + copyThrewException = false; + spillInputChannelPositions[i] += partitionLengthInSpill; + writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); + } + } finally { + Closeables.close(resolvedChannel, copyThrewException); + } + long numBytes = writer.getNumBytesWritten(); + writeMetrics.incBytesWritten(numBytes); + } + threwException = false; + } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. + for (int i = 0; i < spills.length; i++) { + assert (spillInputChannelPositions[i] == spills[i].file.length()); + Closeables.close(spillInputChannels[i], threwException); + } + } + } + + @Override + public Option stop(boolean success) { + try { + taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes()); + + if (stopping) { + return Option.apply(null); + } else { + stopping = true; + if (success) { + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } + return Option.apply(mapStatus); + } else { + return Option.apply(null); + } + } + } finally { + if (sorter != null) { + // If sorter is non-null, then this implies that we called stop() in response to an error, + // so we need to clean up memory and spill files created by the sorter + sorter.cleanupResources(); + } + } + } + + @Override + public long[] getPartitionLengths() { + return new long[0]; + } + + private static final class StreamFallbackChannelWrapper implements WritableByteChannelWrapper { + private final WritableByteChannel channel; + + StreamFallbackChannelWrapper(OutputStream fallbackStream) { + this.channel = Channels.newChannel(fallbackStream); + } + + @Override + public WritableByteChannel channel() { + return channel; + } + + @Override + public void close() throws IOException { + channel.close(); + } + } +} diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/ExposedByteArrayOutputStream.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/ExposedByteArrayOutputStream.java new file mode 100644 index 000000000..8725af577 --- /dev/null +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/ExposedByteArrayOutputStream.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle; + +import java.io.ByteArrayOutputStream; + +/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ +public final class ExposedByteArrayOutputStream extends ByteArrayOutputStream { + ExposedByteArrayOutputStream(int size) { + super(size); + } + + public byte[] getBuf() { + return buf; + } +} diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/ShuffleThreadPool.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/ShuffleThreadPool.java new file mode 100644 index 000000000..69550e47b --- /dev/null +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/ShuffleThreadPool.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle; + +import java.util.concurrent.*; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; + +import org.apache.comet.CometConf$; + +public class ShuffleThreadPool { + private static ThreadPoolExecutor INSTANCE; + + /** Get the thread pool instance for shuffle writer. */ + public static synchronized ExecutorService getThreadPool() { + if (INSTANCE == null) { + boolean isAsync = (boolean) CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED().get(); + + if (isAsync) { + ThreadFactory factory = + new ThreadFactoryBuilder().setNameFormat("async-shuffle-writer-%d").build(); + + int threadNum = (int) CometConf$.MODULE$.COMET_EXEC_SHUFFLE_ASYNC_MAX_THREAD_NUM().get(); + INSTANCE = + new ThreadPoolExecutor( + 0, threadNum, 1L, TimeUnit.SECONDS, new ThreadPoolQueue(threadNum), factory); + } + } + + return INSTANCE; + } +} + +/** + * A blocking queue for thread pool. This will block new task submission until there is space in the + * queue. + */ +final class ThreadPoolQueue extends ArrayBlockingQueue { + public ThreadPoolQueue(int capacity) { + super(capacity); + } + + @Override + public boolean offer(Runnable e) { + try { + put(e); + } catch (InterruptedException e1) { + Thread.currentThread().interrupt(); + return false; + } + return true; + } +} diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillInfo.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillInfo.java new file mode 100644 index 000000000..210937b58 --- /dev/null +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillInfo.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle; + +import java.io.File; + +import org.apache.spark.storage.TempShuffleBlockId; + +/** Metadata for a block of data written by ShuffleExternalSorter. */ +public final class SpillInfo { + public final long[] partitionLengths; + public final File file; + final TempShuffleBlockId blockId; + + public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) { + this.partitionLengths = new long[numPartitions]; + this.file = file; + this.blockId = blockId; + } +} diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java new file mode 100644 index 000000000..e6f973f83 --- /dev/null +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.util.LinkedList; +import java.util.Locale; +import javax.annotation.Nullable; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.memory.SparkOutOfMemoryError; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.shuffle.comet.CometShuffleMemoryAllocator; +import org.apache.spark.shuffle.sort.RowPartition; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.memory.MemoryBlock; + +import org.apache.comet.Native; +import org.apache.comet.serde.QueryPlanSerde$; + +/** + * The interface for writing records into disk. This interface takes input rows and stores them in + * allocated memory pages. When certain condition is met, the writer will spill the content of + * memory to disk. + */ +public abstract class SpillWriter { + private static final Logger logger = LoggerFactory.getLogger(SpillWriter.class); + + /** + * Memory pages that hold the records being sorted. The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). + */ + protected LinkedList allocatedPages; + + @Nullable protected MemoryBlock currentPage = null; + protected long pageCursor = -1; + + // The memory allocator for this sorter. It is used to allocate/free memory pages for this sorter. + // Because we need to allocate off-heap memory regardless of configured Spark memory mode + // (on-heap/off-heap), we need a separate memory allocator. + protected CometShuffleMemoryAllocator allocator; + + protected Native nativeLib; + + protected byte[][] dataTypes; + + // 0: CRC32, 1: Adler32. Spark uses Adler32 by default. + protected int checksumAlgo = 1; + protected long checksum = -1; + + /** Serialize row schema to byte array. */ + protected byte[][] serializeSchema(StructType schema) { + byte[][] dataTypes = new byte[schema.length()][]; + for (int i = 0; i < schema.length(); i++) { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + try { + QueryPlanSerde$.MODULE$ + .serializeDataType(schema.apply(i).dataType()) + .get() + .writeTo(outputStream); + } catch (IOException e) { + throw new RuntimeException(e); + } + dataTypes[i] = outputStream.toByteArray(); + } + + return dataTypes; + } + + protected void setChecksumAlgo(String checksumAlgo) { + String algo = checksumAlgo.toLowerCase(Locale.ROOT); + + if (algo.equals("crc32")) { + this.checksumAlgo = 0; + } else if (algo.equals("adler32")) { + this.checksumAlgo = 1; + } else { + throw new UnsupportedOperationException( + "Unsupported shuffle checksum algorithm: " + checksumAlgo); + } + } + + protected void setChecksum(long checksum) { + this.checksum = checksum; + } + + protected long getChecksum() { + return checksum; + } + + /** + * Spills the current in-memory records to disk. + * + * @param required the amount of required memory. + */ + protected abstract void spill(int required) throws IOException; + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the memory manager and spill if the requested memory can not be obtained. + * + * @param required the required space in the data page, in bytes, including space for storing the + * record size. This must be less than or equal to the page size (records that exceed the page + * size are handled via a different code path which uses special overflow pages). + * @return true if the memory is allocated successfully, false otherwise. If false is returned, it + * means spilling is happening and the caller should not continue insert into this writer. + */ + public boolean acquireNewPageIfNecessary(int required) { + if (currentPage == null + || pageCursor + required > currentPage.getBaseOffset() + currentPage.size()) { + // TODO: try to find space in previous pages + try { + currentPage = allocator.allocatePage(required); + } catch (SparkOutOfMemoryError error) { + try { + // Cannot allocate enough memory, spill + spill(required); + return false; + } catch (IOException e) { + throw new RuntimeException("Unable to spill() in order to acquire " + required, e); + } + } + assert (currentPage != null); + pageCursor = currentPage.getBaseOffset(); + allocatedPages.add(currentPage); + } + return true; + } + + /** Allocates initial memory page */ + public void initialCurrentPage(int required) { + assert (currentPage == null); + try { + currentPage = allocator.allocatePage(required); + } catch (SparkOutOfMemoryError e) { + logger.error("Unable to acquire {} bytes of memory", required); + throw e; + } + assert (currentPage != null); + pageCursor = currentPage.getBaseOffset(); + allocatedPages.add(currentPage); + } + + /** The core logic of spilling buffered rows into disk. */ + protected long doSpilling( + byte[][] dataTypes, + File file, + RowPartition rowPartition, + ShuffleWriteMetricsReporter writeMetricsToUse, + double preferDictionaryRatio) { + long[] addresses = rowPartition.getRowAddresses(); + int[] sizes = rowPartition.getRowSizes(); + + boolean checksumEnabled = checksum != -1; + long currentChecksum = checksumEnabled ? checksum : 0L; + + long start = System.nanoTime(); + long[] results = + nativeLib.writeSortedFileNative( + addresses, + sizes, + dataTypes, + file.getAbsolutePath(), + preferDictionaryRatio, + checksumEnabled, + checksumAlgo, + currentChecksum); + + long written = results[0]; + checksum = results[1]; + + rowPartition.reset(); + + // Update metrics + // Other threads may be updating the metrics at the same time, so we need to synchronize it. + synchronized (writeMetricsToUse) { + writeMetricsToUse.incWriteTime(System.nanoTime() - start); + writeMetricsToUse.incRecordsWritten(addresses.length); + writeMetricsToUse.incBytesWritten(written); + } + + return written; + } + + /** Frees allocated memory pages of this writer */ + public long freeMemory() { + long freed = 0L; + for (MemoryBlock block : allocatedPages) { + freed += block.size(); + allocator.free(block); + } + allocatedPages.clear(); + currentPage = null; + pageCursor = 0; + + return freed; + } + + /** Returns the amount of memory used by this writer, in bytes. */ + public long getMemoryUsage() { + // Assume this method won't be called on a spilling writer, so we don't need to synchronize it. + long used = 0; + + for (MemoryBlock block : allocatedPages) { + used += block.size(); + } + + return used; + } +} diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index cd0b830b2..69d1fb367 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -29,17 +29,20 @@ import org.apache.spark.sql.SparkSessionExtensions import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.comet._ +import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.comet.CometConf._ -import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isSchemaSupported} +import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported} import org.apache.comet.parquet.{CometParquetScan, SupportsComet} import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.QueryPlanSerde @@ -125,6 +128,31 @@ class CometSparkSessionExtensions } case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { + private def applyCometShuffle(plan: SparkPlan): SparkPlan = { + plan.transformUp { + case s: ShuffleExchangeExec + if isCometPlan(s.child) && !isCometColumnarShuffleEnabled(conf) && + QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning) => + logInfo("Comet extension enabled for Native Shuffle") + + // Switch to use Decimal128 regardless of precision, since Arrow native execution + // doesn't support Decimal32 and Decimal64 yet. + conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true") + CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle) + + // Columnar shuffle for regular Spark operators (not Comet) and Comet operators + // (if configured) + case s: ShuffleExchangeExec + if (!s.child.supportsColumnar || isCometPlan( + s.child)) && isCometColumnarShuffleEnabled(conf) && + QueryPlanSerde.supportPartitioningTypes(s.child.output) => + logInfo("Comet extension enabled for JVM Columnar Shuffle") + CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle) + } + } + + private def isCometPlan(op: SparkPlan): Boolean = op.isInstanceOf[CometPlan] + private def isCometNative(op: SparkPlan): Boolean = op.isInstanceOf[CometNativeExec] // spotless:off @@ -137,7 +165,7 @@ class CometSparkSessionExtensions * 1. The child(ren) of the current node `p` cannot be converted to native * In this case, we'll simply return the original Spark plan, since Comet native * execution cannot start from an arbitrary Spark operator (unless it is special node - * such as scan or sink such as union etc., which are wrapped by + * such as scan or sink such as shuffle exchange, union etc., which are wrapped by * `CometScanWrapper` and `CometSinkPlaceHolder` respectively). * * 2. The child(ren) of the current node `p` can be converted to native @@ -295,7 +323,46 @@ class CometSparkSessionExtensions case Some(nativeOp) => val cometOp = CometUnionExec(u, u.children) CometSinkPlaceHolder(nativeOp, u, cometOp) - case None => u + } + + // Native shuffle for Comet operators + case s: ShuffleExchangeExec + if isCometShuffleEnabled(conf) && + !isCometColumnarShuffleEnabled(conf) && + QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning) => + logInfo("Comet extension enabled for Native Shuffle") + + val newOp = transform1(s) + newOp match { + case Some(nativeOp) => + // Switch to use Decimal128 regardless of precision, since Arrow native execution + // doesn't support Decimal32 and Decimal64 yet. + conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true") + val cometOp = CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle) + CometSinkPlaceHolder(nativeOp, s, cometOp) + case None => + s + } + + // Columnar shuffle for regular Spark operators (not Comet) and Comet operators + // (if configured) + case s: ShuffleExchangeExec + if isCometShuffleEnabled(conf) && isCometColumnarShuffleEnabled(conf) && + QueryPlanSerde.supportPartitioningTypes(s.child.output) => + logInfo("Comet extension enabled for JVM Columnar Shuffle") + + val newOp = QueryPlanSerde.operator2Proto(s) + newOp match { + case Some(nativeOp) => + s.child match { + case n if n.isInstanceOf[CometNativeExec] || !n.supportsColumnar => + val cometOp = CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle) + CometSinkPlaceHolder(nativeOp, s, cometOp) + case _ => + s + } + case None => + s } case op => @@ -316,8 +383,12 @@ class CometSparkSessionExtensions if (!isCometEnabled(conf)) return plan if (!isCometExecEnabled(conf)) { - // Comet exec is disabled - plan + // Comet exec is disabled, but for Spark shuffle, we still can use Comet columnar shuffle + if (isCometShuffleEnabled(conf)) { + applyCometShuffle(plan) + } else { + plan + } } else { var newPlan = transform(plan) @@ -405,6 +476,11 @@ object CometSparkSessionExtensions extends Logging { conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf) } + private[comet] def isCometShuffleEnabled(conf: SQLConf): Boolean = + COMET_EXEC_SHUFFLE_ENABLED.get(conf) && + (conf.contains("spark.shuffle.manager") && conf.getConfString("spark.shuffle.manager") == + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") + private[comet] def isCometScanEnabled(conf: SQLConf): Boolean = { COMET_SCAN_ENABLED.get(conf) } @@ -413,6 +489,10 @@ object CometSparkSessionExtensions extends Logging { COMET_EXEC_ENABLED.get(conf) } + private[comet] def isCometColumnarShuffleEnabled(conf: SQLConf): Boolean = { + COMET_COLUMNAR_SHUFFLE_ENABLED.get(conf) + } + private[comet] def isCometAllOperatorEnabled(conf: SQLConf): Boolean = { COMET_EXEC_ALL_OPERATOR_ENABLED.get(conf) } @@ -470,4 +550,23 @@ object CometSparkSessionExtensions extends Logging { def getCometMemoryOverhead(sparkConf: SparkConf): Long = { ByteUnit.MiB.toBytes(getCometMemoryOverheadInMiB(sparkConf)) } + + /** Calculates required shuffle memory size in bytes per executor process for Comet. */ + def getCometShuffleMemorySize(sparkConf: SparkConf, conf: SQLConf): Long = { + val cometMemoryOverhead = getCometMemoryOverheadInMiB(sparkConf) + + val overheadFactor = COMET_COLUMNAR_SHUFFLE_MEMORY_FACTOR.get(conf) + val cometShuffleMemoryFromConf = COMET_COLUMNAR_SHUFFLE_MEMORY_SIZE.get(conf) + + val shuffleMemorySize = + cometShuffleMemoryFromConf.getOrElse((overheadFactor * cometMemoryOverhead).toLong) + if (shuffleMemorySize > cometMemoryOverhead) { + logWarning( + s"Configured shuffle memory size $shuffleMemorySize is larger than Comet memory overhead " + + s"$cometMemoryOverhead, using Comet memory overhead instead.") + ByteUnit.MiB.toBytes(cometMemoryOverhead) + } else { + ByteUnit.MiB.toBytes(shuffleMemorySize) + } + } } diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index c930c7d76..05bada522 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -104,4 +104,49 @@ class Native extends NativeBase { * the address to native query plan. */ @native def releasePlan(plan: Long): Unit + + /** + * Used by Comet shuffle external sorter to write sorted records to disk. + * + * @param addresses + * the array of addresses of Spark unsafe rows. + * @param rowSizes + * the row sizes of Spark unsafe rows. + * @param datatypes + * the datatypes of fields in Spark unsafe rows. + * @param file + * the file path to write to. + * @param preferDictionaryRatio + * the ratio of total values to distinct values in a string column that makes the writer to + * prefer dictionary encoding. If it is larger than the specified ratio, dictionary encoding + * will be used when writing columns of string type. + * @param checksumEnabled + * whether to compute checksum of written file. + * @param checksumAlgo + * the checksum algorithm to use. 0 for CRC32, 1 for Adler32. + * @param currentChecksum + * the current checksum of the file. As the checksum is computed incrementally, this is used + * to resume the computation of checksum for previous written data. + * @return + * [the number of bytes written to disk, the checksum] + */ + @native def writeSortedFileNative( + addresses: Array[Long], + rowSizes: Array[Int], + datatypes: Array[Array[Byte]], + file: String, + preferDictionaryRatio: Double, + checksumEnabled: Boolean, + checksumAlgo: Int, + currentChecksum: Long): Array[Long] + + /** + * Sorts partition ids of Spark unsafe rows in place. Used by Comet shuffle external sorter. + * + * @param addr + * the address of the array of compacted partition ids. + * @param size + * the size of the array. + */ + @native def sortRowPartitionsNative(addr: Long, size: Long): Unit } diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala new file mode 100644 index 000000000..f89dbb8db --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleExchangeExec, ShuffleType} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec + +trait ShimCometShuffleExchangeExec { + // TODO: remove after dropping Spark 3.2 and 3.3 support + def apply(s: ShuffleExchangeExec, shuffleType: ShuffleType): CometShuffleExchangeExec = { + val advisoryPartitionSize = s.getClass.getDeclaredMethods + .filter(_.getName == "advisoryPartitionSize") + .flatMap(_.invoke(s).asInstanceOf[Option[Long]]) + .headOption + CometShuffleExchangeExec( + s.outputPartitioning, + s.child, + s.shuffleOrigin, + shuffleType, + advisoryPartitionSize) + } +} diff --git a/spark/src/main/scala/org/apache/spark/Plugins.scala b/spark/src/main/scala/org/apache/spark/Plugins.scala index 113e3a47e..97838448a 100644 --- a/spark/src/main/scala/org/apache/spark/Plugins.scala +++ b/spark/src/main/scala/org/apache/spark/Plugins.scala @@ -90,11 +90,13 @@ class CometDriverPlugin extends DriverPlugin with Logging { /** * Whether we should override Spark memory configuration for Comet. This only returns true when - * Comet native execution is enabled + * Comet native execution is enabled and/or Comet shuffle is enabled */ private def shouldOverrideMemoryConf(conf: SparkConf): Boolean = { - conf.getBoolean(CometConf.COMET_ENABLED.key, true) && - conf.getBoolean(CometConf.COMET_EXEC_ENABLED.key, false) + conf.getBoolean(CometConf.COMET_ENABLED.key, true) && ( + conf.getBoolean(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, false) || + conf.getBoolean(CometConf.COMET_EXEC_ENABLED.key, false) + ) } } diff --git a/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala b/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala new file mode 100644 index 000000000..873e422fb --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/shuffle/sort/RowPartition.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.shuffle.sort + +import scala.collection.mutable.ArrayBuffer + +class RowPartition(initialSize: Int) { + private val rowAddresses: ArrayBuffer[Long] = new ArrayBuffer[Long](initialSize) + private val rowSizes: ArrayBuffer[Int] = new ArrayBuffer[Int](initialSize) + + def addRow(addr: Long, size: Int): Unit = { + rowAddresses += addr + rowSizes += size + } + + def getNumRows: Int = rowAddresses.size + + def getRowAddresses: Array[Long] = rowAddresses.toArray + def getRowSizes: Array[Int] = rowSizes.toArray + + def reset(): Unit = { + rowAddresses.clear() + rowSizes.clear() + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala new file mode 100644 index 000000000..47e6dc70c --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -0,0 +1,381 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle + +import java.nio.{ByteBuffer, ByteOrder} +import java.nio.file.{Files, Paths} + +import scala.collection.JavaConverters.asJavaIterableConverter +import scala.concurrent.Future + +import org.apache.spark._ +import org.apache.spark.rdd.{MapPartitionsRDD, RDD} +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsReporter, ShuffleWriteProcessor} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometPlan} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike, ShuffleOrigin} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde} +import org.apache.comet.serde.OperatorOuterClass.Operator +import org.apache.comet.serde.QueryPlanSerde.serializeDataType +import org.apache.comet.shims.ShimCometShuffleExchangeExec + +/** + * Performs a shuffle that will result in the desired partitioning. + */ +case class CometShuffleExchangeExec( + override val outputPartitioning: Partitioning, + child: SparkPlan, + shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS, + shuffleType: ShuffleType = CometNativeShuffle, + advisoryPartitionSize: Option[Long] = None) + extends ShuffleExchangeLike + with CometPlan { + + private lazy val writeMetrics = + SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) + private[sql] lazy val readMetrics = + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) + override lazy val metrics: Map[String, SQLMetric] = Map( + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "shuffleReadElapsedCompute" -> + SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle read elapsed compute at native"), + "numPartitions" -> SQLMetrics.createMetric( + sparkContext, + "number of partitions")) ++ readMetrics ++ writeMetrics + + override def nodeName: String = if (shuffleType == CometNativeShuffle) { + "CometExchange" + } else { + "CometColumnarExchange" + } + + private lazy val serializer: Serializer = + new UnsafeRowSerializer(child.output.size, longMetric("dataSize")) + + @transient lazy val inputRDD: RDD[_] = if (shuffleType == CometNativeShuffle) { + // CometNativeShuffle assumes that the input plan is Comet plan. + child.executeColumnar() + } else if (shuffleType == CometColumnarShuffle) { + // CometColumnarShuffle assumes that the input plan is row-based plan from Spark. + // One exception is that the input plan is CometScanExec which manually converts + // ColumnarBatch to InternalRow in its doExecute(). + child.execute() + } else { + throw new UnsupportedOperationException( + s"Unsupported shuffle type: ${shuffleType.getClass.getName}") + } + + // 'mapOutputStatisticsFuture' is only needed when enable AQE. + @transient + override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = { + if (inputRDD.getNumPartitions == 0) { + Future.successful(null) + } else { + sparkContext.submitMapStage(shuffleDependency) + } + } + + override def numMappers: Int = shuffleDependency.rdd.getNumPartitions + + override def numPartitions: Int = shuffleDependency.partitioner.numPartitions + + override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] = + new CometShuffledBatchRDD(shuffleDependency, readMetrics, partitionSpecs) + + override def runtimeStatistics: Statistics = { + val dataSize = metrics("dataSize").value + val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value + Statistics(dataSize, Some(rowCount)) + } + + /** + * A [[ShuffleDependency]] that will partition rows of its child based on the partitioning + * scheme defined in `newPartitioning`. Those partitions of the returned ShuffleDependency will + * be the input of shuffle. + */ + @transient + lazy val shuffleDependency: ShuffleDependency[Int, _, _] = + if (shuffleType == CometNativeShuffle) { + val dep = CometShuffleExchangeExec.prepareShuffleDependency( + inputRDD.asInstanceOf[RDD[ColumnarBatch]], + child.output, + outputPartitioning, + serializer, + metrics) + metrics("numPartitions").set(dep.partitioner.numPartitions) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates( + sparkContext, + executionId, + metrics("numPartitions") :: Nil) + dep + } else if (shuffleType == CometColumnarShuffle) { + val dep = CometShuffleExchangeExec.prepareJVMShuffleDependency( + inputRDD.asInstanceOf[RDD[InternalRow]], + child.output, + outputPartitioning, + serializer, + metrics) + metrics("numPartitions").set(dep.partitioner.numPartitions) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates( + sparkContext, + executionId, + metrics("numPartitions") :: Nil) + dep + } else { + throw new UnsupportedOperationException( + s"Unsupported shuffle type: ${shuffleType.getClass.getName}") + } + + protected override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException( + "CometShuffleExchangeExec.doExecute should not be executed.") + } + + /** + * Comet supports columnar execution. + */ + override val supportsColumnar: Boolean = true + + /** + * Caches the created CometShuffledBatchRDD so we can reuse that. + */ + private var cachedShuffleRDD: CometShuffledBatchRDD = null + + /** + * Comet returns RDD[ColumnarBatch] for columnar execution. + */ + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + // Returns the same CometShuffledBatchRDD if this plan is used by multiple plans. + if (cachedShuffleRDD == null) { + cachedShuffleRDD = new CometShuffledBatchRDD(shuffleDependency, readMetrics) + } + cachedShuffleRDD + } + + override protected def withNewChildInternal(newChild: SparkPlan): CometShuffleExchangeExec = + copy(child = newChild) +} + +object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { + def prepareShuffleDependency( + rdd: RDD[ColumnarBatch], + outputAttributes: Seq[Attribute], + outputPartitioning: Partitioning, + serializer: Serializer, + metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { + val dependency = new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch]( + rdd.map( + (0, _) + ), // adding fake partitionId that is always 0 because ShuffleDependency requires it + serializer = serializer, + shuffleWriterProcessor = + new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics), + shuffleType = CometNativeShuffle, + partitioner = new Partitioner { + override def numPartitions: Int = outputPartitioning.numPartitions + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + }) + dependency + } + + /** + * Returns a [[ShuffleDependency]] that will partition rows of its child based on the + * partitioning scheme defined in `newPartitioning`. Those partitions of the returned + * ShuffleDependency will be the input of shuffle. + */ + def prepareJVMShuffleDependency( + rdd: RDD[InternalRow], + outputAttributes: Seq[Attribute], + newPartitioning: Partitioning, + serializer: Serializer, + writeMetrics: Map[String, SQLMetric]): ShuffleDependency[Int, InternalRow, InternalRow] = { + val sparkShuffleDep = ShuffleExchangeExec.prepareShuffleDependency( + rdd, + outputAttributes, + newPartitioning, + serializer, + writeMetrics) + + val dependency = + new CometShuffleDependency[Int, InternalRow, InternalRow]( + sparkShuffleDep.rdd, + sparkShuffleDep.partitioner, + sparkShuffleDep.serializer, + shuffleWriterProcessor = sparkShuffleDep.shuffleWriterProcessor, + shuffleType = CometColumnarShuffle, + schema = Some(StructType.fromAttributes(outputAttributes))) + dependency + } +} + +/** + * A [[ShuffleWriteProcessor]] that will delegate shuffle write to native shuffle. + * @param metrics + * metrics to report + */ +class CometShuffleWriteProcessor( + outputPartitioning: Partitioning, + outputAttributes: Seq[Attribute], + metrics: Map[String, SQLMetric]) + extends ShuffleWriteProcessor { + + private val OFFSET_LENGTH = 8 + + override protected def createMetricsReporter( + context: TaskContext): ShuffleWriteMetricsReporter = { + new SQLShuffleWriteMetricsReporter(context.taskMetrics().shuffleWriteMetrics, metrics) + } + + override def write( + rdd: RDD[_], + dep: ShuffleDependency[_, _, _], + mapId: Long, + context: TaskContext, + partition: Partition): MapStatus = { + val shuffleBlockResolver = + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver] + val dataFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) + val indexFile = shuffleBlockResolver.getIndexFile(dep.shuffleId, mapId) + val tempDataFilename = dataFile.getPath.replace(".data", ".data.tmp") + val tempIndexFilename = indexFile.getPath.replace(".index", ".index.tmp") + val tempDataFilePath = Paths.get(tempDataFilename) + val tempIndexFilePath = Paths.get(tempIndexFilename) + + // Getting rid of the fake partitionId + val cometRDD = + rdd.asInstanceOf[MapPartitionsRDD[_, _]].prev.asInstanceOf[RDD[ColumnarBatch]] + + // Call native shuffle write + val nativePlan = getNativePlan(tempDataFilename, tempIndexFilename) + + // Maps native metrics to SQL metrics + val nativeSQLMetrics = Map( + "output_rows" -> metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN), + "elapsed_compute" -> metrics("shuffleReadElapsedCompute")) + val nativeMetrics = CometMetricNode(nativeSQLMetrics) + + val rawIter = cometRDD.iterator(partition, context) + val cometIter = CometExec.getCometIterator(Seq(rawIter), nativePlan, nativeMetrics) + + while (cometIter.hasNext) { + cometIter.next() + } + + // get partition lengths from shuffle write output index file + var offset = 0L + val partitionLengths = Files + .readAllBytes(tempIndexFilePath) + .grouped(OFFSET_LENGTH) + .drop(1) // first partition offset is always 0 + .map(indexBytes => { + val partitionOffset = + ByteBuffer.wrap(indexBytes).order(ByteOrder.LITTLE_ENDIAN).getLong + val partitionLength = partitionOffset - offset + offset = partitionOffset + partitionLength + }) + .toArray + + // Update Spark metrics from native metrics + metrics("dataSize") += Files.size(tempDataFilePath) + + // commit + shuffleBlockResolver.writeMetadataFileAndCommit( + dep.shuffleId, + mapId, + partitionLengths, + Array.empty, // TODO: add checksums + tempDataFilePath.toFile) + MapStatus.apply(SparkEnv.get.blockManager.shuffleServerId, partitionLengths, mapId) + } + + def getNativePlan(dataFile: String, indexFile: String): Operator = { + val scanBuilder = OperatorOuterClass.Scan.newBuilder() + val opBuilder = OperatorOuterClass.Operator.newBuilder() + + val scanTypes = outputAttributes.flatten { attr => + serializeDataType(attr.dataType) + } + + if (scanTypes.length == outputAttributes.length) { + scanBuilder.addAllFields(scanTypes.asJava) + + val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder() + shuffleWriterBuilder.setOutputDataFile(dataFile) + shuffleWriterBuilder.setOutputIndexFile(indexFile) + + outputPartitioning match { + case _: HashPartitioning => + val hashPartitioning = outputPartitioning.asInstanceOf[HashPartitioning] + + val partitioning = PartitioningOuterClass.HashRepartition.newBuilder() + partitioning.setNumPartitions(outputPartitioning.numPartitions) + + val partitionExprs = hashPartitioning.expressions + .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) + + if (partitionExprs.length != hashPartitioning.expressions.length) { + throw new UnsupportedOperationException( + s"Partitioning $hashPartitioning is not supported.") + } + + partitioning.addAllHashExpression(partitionExprs.asJava) + + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setHashPartition(partitioning).build()) + + case SinglePartition => + val partitioning = PartitioningOuterClass.SinglePartition.newBuilder() + + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setSinglePartition(partitioning).build()) + + case _ => + throw new UnsupportedOperationException( + s"Partitioning $outputPartitioning is not supported.") + } + + val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder() + shuffleWriterOpBuilder + .setShuffleWriter(shuffleWriterBuilder) + .addChildren(opBuilder.setScan(scanBuilder).build()) + .build() + } else { + // There are unsupported scan type + throw new UnsupportedOperationException( + s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.") + } + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala new file mode 100644 index 000000000..51e6df578 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle + +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ + +import org.apache.spark.ShuffleDependency +import org.apache.spark.SparkConf +import org.apache.spark.SparkEnv +import org.apache.spark.TaskContext +import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.config.IO_COMPRESSION_CODEC +import org.apache.spark.io.CompressionCodec +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.{BypassMergeSortShuffleHandle, SerializedShuffleHandle, SortShuffleManager, SortShuffleWriter} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.collection.OpenHashSet + +import org.apache.comet.CometConf + +/** + * A [[ShuffleManager]] that uses Arrow format to shuffle data. + */ +class CometShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + import CometShuffleManager._ + import SortShuffleManager._ + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } + + private val sortShuffleManager = new SortShuffleManager(conf); + + /** + * A mapping from shuffle ids to the task ids of mappers producing output for those shuffles. + */ + private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]() + + private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) + + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + + /** + * (override) Obtains a [[ShuffleHandle]] to pass to tasks. + */ + def registerShuffle[K, V, C]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + if (dependency.isInstanceOf[CometShuffleDependency[_, _, _]]) { + // Comet shuffle dependency, which comes from `CometShuffleExchangeExec`. + if (shouldBypassMergeSort(conf, dependency) || + !SortShuffleManager.canUseSerializedShuffle(dependency)) { + new CometBypassMergeSortShuffleHandle( + shuffleId, + dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + new CometSerializedShuffleHandle( + shuffleId, + dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } + } else { + // It is a Spark shuffle dependency, so we use Spark Sort Shuffle Manager. + if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need map-side aggregation, then write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. + new BypassMergeSortShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { + // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: + new SerializedShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + // Otherwise, buffer map outputs in a deserialized form: + new BaseShuffleHandle(shuffleId, dependency) + } + } + } + + override def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, _, C]] + val (blocksByAddress, canEnableBatchFetch) = + if (baseShuffleHandle.dependency.shuffleMergeEnabled) { + val res = SparkEnv.get.mapOutputTracker.getPushBasedShuffleMapSizesByExecutorId( + handle.shuffleId, + startMapIndex, + endMapIndex, + startPartition, + endPartition) + (res.iter, res.enableBatchFetch) + } else { + val address = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, + startMapIndex, + endMapIndex, + startPartition, + endPartition) + (address, true) + } + + if (handle.isInstanceOf[CometBypassMergeSortShuffleHandle[_, _]] || + handle.isInstanceOf[CometSerializedShuffleHandle[_, _]]) { + new CometBlockStoreShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + blocksByAddress, + context, + metrics, + shouldBatchFetch = + canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)) + } else { + // It is a Spark shuffle dependency, so we use Spark Sort Shuffle Reader. + sortShuffleManager.getReader( + handle, + startMapIndex, + endMapIndex, + startPartition, + endPartition, + context, + metrics) + } + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + val mapTaskIds = + taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16)) + mapTaskIds.synchronized { + mapTaskIds.add(context.taskAttemptId()) + } + val env = SparkEnv.get + handle match { + case bypassMergeSortHandle: CometBypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => + new CometBypassMergeSortShuffleWriter( + env.blockManager, + context.taskMemoryManager(), + context, + bypassMergeSortHandle, + mapId, + env.conf, + metrics, + shuffleExecutorComponents) + case unsafeShuffleHandle: CometSerializedShuffleHandle[K @unchecked, V @unchecked] => + new CometUnsafeShuffleWriter( + env.blockManager, + context.taskMemoryManager(), + unsafeShuffleHandle, + mapId, + context, + env.conf, + metrics, + shuffleExecutorComponents) + case _ => + // It is a Spark shuffle dependency, so we use Spark Sort Shuffle Writer. + sortShuffleManager.getWriter(handle, mapId, context, metrics) + } + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Boolean = { + Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { mapTaskIds => + mapTaskIds.iterator.foreach { mapTaskId => + shuffleBlockResolver.removeDataByMap(shuffleId, mapTaskId) + } + } + true + } + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = { + shuffleBlockResolver.stop() + } +} + +object CometShuffleManager extends Logging { + + /** + * Loads executor components for shuffle data IO. + */ + private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { + val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor() + val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap + executorComponents.initializeExecutor( + conf.getAppId, + SparkEnv.get.executorId, + extraConfigs.asJava) + executorComponents + } + + lazy val compressionCodecForShuffling: CompressionCodec = { + val sparkConf = SparkEnv.get.conf + val codecName = CometConf.COMET_EXEC_SHUFFLE_CODEC.get(SQLConf.get) + + // only zstd compression is supported at the moment + if (codecName != "zstd") { + logWarning( + s"Overriding config ${IO_COMPRESSION_CODEC}=${codecName} in shuffling, force using zstd") + } + CompressionCodec.createCodec(sparkConf, "zstd") + } + + def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { + // We cannot bypass sorting if we need to do map-side aggregation. + if (dep.mapSideCombine) { + false + } else { + // Condition from Spark: + // Bypass merge sort if we have fewer than `spark.shuffle.sort.bypassMergeThreshold` + val partitionCond = SortShuffleWriter.shouldBypassMergeSort(conf, dep) + + // Bypass merge sort if we have partition * cores fewer than + // `spark.comet.columnar.shuffle.async.max.thread.num` + val executorCores = conf.get(config.EXECUTOR_CORES) + val maxThreads = CometConf.COMET_EXEC_SHUFFLE_ASYNC_MAX_THREAD_NUM.get(SQLConf.get) + val threadCond = dep.partitioner.numPartitions * executorCores <= maxThreads + + // Comet columnar shuffle buffers rows in memory. If too many cores are used with + // relatively high number of partitions, it may cause OOM when initializing the + // hash-based shuffle writers at beginning of the task. For example, 10 cores + // with 100 partitions will allocates 1000 writers. Sort-based shuffle doesn't have + // this issue because it only allocates one writer per task. + partitionCond && threadCond + } + } +} + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the bypass merge + * sort shuffle path. + */ +private[spark] class CometBypassMergeSortShuffleHandle[K, V]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, dependency) {} + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the serialized + * shuffle. + */ +private[spark] class CometSerializedShuffleHandle[K, V]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, dependency) {} diff --git a/spark/src/test/scala/org/apache/comet/CometSparkSessionExtensionsSuite.scala b/spark/src/test/scala/org/apache/comet/CometSparkSessionExtensionsSuite.scala index 2c818b97a..fdcd307be 100644 --- a/spark/src/test/scala/org/apache/comet/CometSparkSessionExtensionsSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometSparkSessionExtensionsSuite.scala @@ -142,4 +142,42 @@ class CometSparkSessionExtensionsSuite extends CometTestBase { CometSparkSessionExtensions .getCometMemoryOverhead(sparkConf) == getBytesFromMib(1024 * 10)) } + + test("Comet shuffle memory factor") { + val conf = new SparkConf() + + val sqlConf = new SQLConf + sqlConf.setConfString(CometConf.COMET_COLUMNAR_SHUFFLE_MEMORY_FACTOR.key, "0.2") + + // Minimum Comet memory overhead is 384MB + assert( + CometSparkSessionExtensions.getCometShuffleMemorySize(conf, sqlConf) == + getBytesFromMib((384 * 0.2).toLong)) + + conf.set(CometConf.COMET_MEMORY_OVERHEAD_FACTOR.key, "0.5") + assert( + CometSparkSessionExtensions.getCometShuffleMemorySize(conf, sqlConf) == + getBytesFromMib((1024 * 0.5 * 0.2).toLong)) + } + + test("Comet shuffle memory") { + val conf = new SparkConf() + val sqlConf = new SQLConf + conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "1g") + sqlConf.setConfString(CometConf.COMET_COLUMNAR_SHUFFLE_MEMORY_SIZE.key, "512m") + + assert( + CometSparkSessionExtensions + .getCometShuffleMemorySize(conf, sqlConf) == getBytesFromMib(512)) + } + + test("Comet shuffle memory cannot be larger than Comet memory overhead") { + val conf = new SparkConf() + val sqlConf = new SQLConf + conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "1g") + sqlConf.setConfString(CometConf.COMET_COLUMNAR_SHUFFLE_MEMORY_SIZE.key, "10g") + assert( + CometSparkSessionExtensions + .getCometShuffleMemorySize(conf, sqlConf) == getBytesFromMib(1024)) + } } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index f0cc96a6b..da096e56e 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -39,7 +39,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("Fix NPE in partial decimal sum") { val table = "tbl" withTable(table) { - withSQLConf(CometConf.COMET_ENABLED.key -> "true") { + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "false", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") { withTable(table) { sql(s"CREATE TABLE $table(col DECIMAL(5, 2)) USING PARQUET") sql(s"INSERT INTO TABLE $table VALUES (CAST(12345.01 AS DECIMAL(5, 2)))") @@ -50,6 +53,22 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("fix: Decimal Average should not enable native final aggregation") { + withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test") + makeParquetFile(path, 1000, 10, dictionaryEnabled) + withParquetTable(path.toUri.toString, "tbl") { + checkSparkAnswer("SELECT _g1, AVG(_7) FROM tbl GROUP BY _g1") + checkSparkAnswer("SELECT _g1, AVG(_8) FROM tbl GROUP BY _g1") + checkSparkAnswer("SELECT _g1, AVG(_9) FROM tbl GROUP BY _g1") + } + } + } + } + } + test("trivial case") { Seq(true, false).foreach { dictionaryEnabled => withParquetTable((0 until 5).map(i => (i, i)), "tbl", dictionaryEnabled) { @@ -119,14 +138,23 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("group-by on variable length types") { - Seq(true, false).foreach { dictionaryEnabled => - withParquetTable((0 until 100).map(i => (i, (i % 10).toString)), "tbl", dictionaryEnabled) { - val n = 1 - checkSparkAnswerAndNumOfAggregates("SELECT _2, SUM(_1) FROM tbl GROUP BY _2", n) - checkSparkAnswerAndNumOfAggregates("SELECT _2, COUNT(_1) FROM tbl GROUP BY _2", n) - checkSparkAnswerAndNumOfAggregates("SELECT _2, MIN(_1) FROM tbl GROUP BY _2", n) - checkSparkAnswerAndNumOfAggregates("SELECT _2, MAX(_1) FROM tbl GROUP BY _2", n) - checkSparkAnswerAndNumOfAggregates("SELECT _2, AVG(_1) FROM tbl GROUP BY _2", n) + Seq(true, false).foreach { nativeShuffleEnabled => + Seq(true, false).foreach { dictionaryEnabled => + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> nativeShuffleEnabled.toString, + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") { + withParquetTable( + (0 until 100).map(i => (i, (i % 10).toString)), + "tbl", + dictionaryEnabled) { + val n = if (nativeShuffleEnabled) 2 else 1 + checkSparkAnswerAndNumOfAggregates("SELECT _2, SUM(_1) FROM tbl GROUP BY _2", n) + checkSparkAnswerAndNumOfAggregates("SELECT _2, COUNT(_1) FROM tbl GROUP BY _2", n) + checkSparkAnswerAndNumOfAggregates("SELECT _2, MIN(_1) FROM tbl GROUP BY _2", n) + checkSparkAnswerAndNumOfAggregates("SELECT _2, MAX(_1) FROM tbl GROUP BY _2", n) + checkSparkAnswerAndNumOfAggregates("SELECT _2, AVG(_1) FROM tbl GROUP BY _2", n) + } + } } } } @@ -292,30 +320,36 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("SUM decimal with DF") { Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test") - makeParquetFile(path, 1000, 20, dictionaryEnabled) - withParquetTable(path.toUri.toString, "tbl") { - val expectedNumOfCometAggregates = 1 - - checkSparkAnswerAndNumOfAggregates( - "SELECT _g2, SUM(_7) FROM tbl GROUP BY _g2", - expectedNumOfCometAggregates) - checkSparkAnswerAndNumOfAggregates( - "SELECT _g3, SUM(_8) FROM tbl GROUP BY _g3", - expectedNumOfCometAggregates) - checkSparkAnswerAndNumOfAggregates( - "SELECT _g4, SUM(_9) FROM tbl GROUP BY _g4", - expectedNumOfCometAggregates) - checkSparkAnswerAndNumOfAggregates( - "SELECT SUM(_7) FROM tbl", - expectedNumOfCometAggregates) - checkSparkAnswerAndNumOfAggregates( - "SELECT SUM(_8) FROM tbl", - expectedNumOfCometAggregates) - checkSparkAnswerAndNumOfAggregates( - "SELECT SUM(_9) FROM tbl", - expectedNumOfCometAggregates) + Seq(true, false).foreach { nativeShuffleEnabled => + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> nativeShuffleEnabled.toString, + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") { + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test") + makeParquetFile(path, 1000, 20, dictionaryEnabled) + withParquetTable(path.toUri.toString, "tbl") { + val expectedNumOfCometAggregates = if (nativeShuffleEnabled) 2 else 1 + + checkSparkAnswerAndNumOfAggregates( + "SELECT _g2, SUM(_7) FROM tbl GROUP BY _g2", + expectedNumOfCometAggregates) + checkSparkAnswerAndNumOfAggregates( + "SELECT _g3, SUM(_8) FROM tbl GROUP BY _g3", + expectedNumOfCometAggregates) + checkSparkAnswerAndNumOfAggregates( + "SELECT _g4, SUM(_9) FROM tbl GROUP BY _g4", + expectedNumOfCometAggregates) + checkSparkAnswerAndNumOfAggregates( + "SELECT SUM(_7) FROM tbl", + expectedNumOfCometAggregates) + checkSparkAnswerAndNumOfAggregates( + "SELECT SUM(_8) FROM tbl", + expectedNumOfCometAggregates) + checkSparkAnswerAndNumOfAggregates( + "SELECT SUM(_9) FROM tbl", + expectedNumOfCometAggregates) + } + } } } } @@ -442,98 +476,117 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("test final count") { - Seq(false, true).foreach { dictionaryEnabled => - val n = 1 - withParquetTable((0 until 5).map(i => (i, i % 2)), "tbl", dictionaryEnabled) { - checkSparkAnswerAndNumOfAggregates("SELECT _2, COUNT(_1) FROM tbl GROUP BY _2", n) - checkSparkAnswerAndNumOfAggregates("select count(_1) from tbl", n) - checkSparkAnswerAndNumOfAggregates( - "SELECT _2, COUNT(_1), SUM(_1) FROM tbl GROUP BY _2", - n) - checkSparkAnswerAndNumOfAggregates("SELECT COUNT(_1), COUNT(_2) FROM tbl", n) + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") { + Seq(false, true).foreach { dictionaryEnabled => + withParquetTable((0 until 5).map(i => (i, i % 2)), "tbl", dictionaryEnabled) { + checkSparkAnswerAndNumOfAggregates("SELECT _2, COUNT(_1) FROM tbl GROUP BY _2", 2) + checkSparkAnswerAndNumOfAggregates("select count(_1) from tbl", 2) + checkSparkAnswerAndNumOfAggregates( + "SELECT _2, COUNT(_1), SUM(_1) FROM tbl GROUP BY _2", + 2) + checkSparkAnswerAndNumOfAggregates("SELECT COUNT(_1), COUNT(_2) FROM tbl", 2) + } } } } test("test final min/max") { - Seq(true, false).foreach { dictionaryEnabled => - withParquetTable((0 until 5).map(i => (i, i % 2)), "tbl", dictionaryEnabled) { - val n = 1 - checkSparkAnswerAndNumOfAggregates( - "SELECT _2, MIN(_1), MAX(_1), COUNT(_1) FROM tbl GROUP BY _2", - n) - checkSparkAnswerAndNumOfAggregates("SELECT MIN(_1), MAX(_1), COUNT(_1) FROM tbl", 1) - checkSparkAnswerAndNumOfAggregates( - "SELECT _2, MIN(_1), MAX(_1), COUNT(_1), SUM(_1) FROM tbl GROUP BY _2", - n) - checkSparkAnswerAndNumOfAggregates( - "SELECT MIN(_1), MIN(_2), MAX(_1), MAX(_2), COUNT(_1), COUNT(_2) FROM tbl", - n) + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") { + Seq(true, false).foreach { dictionaryEnabled => + withParquetTable((0 until 5).map(i => (i, i % 2)), "tbl", dictionaryEnabled) { + checkSparkAnswerAndNumOfAggregates( + "SELECT _2, MIN(_1), MAX(_1), COUNT(_1) FROM tbl GROUP BY _2", + 2) + checkSparkAnswerAndNumOfAggregates("SELECT MIN(_1), MAX(_1), COUNT(_1) FROM tbl", 2) + checkSparkAnswerAndNumOfAggregates( + "SELECT _2, MIN(_1), MAX(_1), COUNT(_1), SUM(_1) FROM tbl GROUP BY _2", + 2) + checkSparkAnswerAndNumOfAggregates( + "SELECT MIN(_1), MIN(_2), MAX(_1), MAX(_2), COUNT(_1), COUNT(_2) FROM tbl", + 2) + } } } } test("test final min/max/count with result expressions") { - Seq(true, false).foreach { dictionaryEnabled => - withParquetTable((0 until 5).map(i => (i, i % 2)), "tbl", dictionaryEnabled) { - val n = 1 - checkSparkAnswerAndNumOfAggregates( - "SELECT _2, MIN(_1) + 2, COUNT(_1) FROM tbl GROUP BY _2", - n) - checkSparkAnswerAndNumOfAggregates("SELECT _2, COUNT(_1) + 2 FROM tbl GROUP BY _2", n) - checkSparkAnswerAndNumOfAggregates("SELECT _2 + 2, COUNT(_1) FROM tbl GROUP BY _2", n) - checkSparkAnswerAndNumOfAggregates("SELECT _2, MIN(_1) + MAX(_1) FROM tbl GROUP BY _2", n) - checkSparkAnswerAndNumOfAggregates("SELECT _2, MIN(_1) + _2 FROM tbl GROUP BY _2", n) - checkSparkAnswerAndNumOfAggregates( - "SELECT _2 + 2, MIN(_1), MAX(_1), COUNT(_1) FROM tbl GROUP BY _2", - n) - checkSparkAnswerAndNumOfAggregates( - "SELECT _2, MIN(_1), MAX(_1) + 2, COUNT(_1) FROM tbl GROUP BY _2", - n) - checkSparkAnswerAndNumOfAggregates("SELECT _2, SUM(_1) + 2 FROM tbl GROUP BY _2", n) - checkSparkAnswerAndNumOfAggregates("SELECT _2 + 2, SUM(_1) FROM tbl GROUP BY _2", n) - checkSparkAnswerAndNumOfAggregates("SELECT _2, SUM(_1 + 1) FROM tbl GROUP BY _2", n) - - // result expression is unsupported by Comet, so only partial aggregation should be used - val df = sql( - "SELECT _2, MIN(_1) + java_method('java.lang.Math', 'random') " + - "FROM tbl GROUP BY _2") - assert(getNumCometHashAggregate(df) == 1) + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") { + Seq(true, false).foreach { dictionaryEnabled => + withParquetTable((0 until 5).map(i => (i, i % 2)), "tbl", dictionaryEnabled) { + checkSparkAnswerAndNumOfAggregates( + "SELECT _2, MIN(_1) + 2, COUNT(_1) FROM tbl GROUP BY _2", + 2) + checkSparkAnswerAndNumOfAggregates("SELECT _2, COUNT(_1) + 2 FROM tbl GROUP BY _2", 2) + checkSparkAnswerAndNumOfAggregates("SELECT _2 + 2, COUNT(_1) FROM tbl GROUP BY _2", 2) + checkSparkAnswerAndNumOfAggregates( + "SELECT _2, MIN(_1) + MAX(_1) FROM tbl GROUP BY _2", + 2) + checkSparkAnswerAndNumOfAggregates("SELECT _2, MIN(_1) + _2 FROM tbl GROUP BY _2", 2) + checkSparkAnswerAndNumOfAggregates( + "SELECT _2 + 2, MIN(_1), MAX(_1), COUNT(_1) FROM tbl GROUP BY _2", + 2) + checkSparkAnswerAndNumOfAggregates( + "SELECT _2, MIN(_1), MAX(_1) + 2, COUNT(_1) FROM tbl GROUP BY _2", + 2) + checkSparkAnswerAndNumOfAggregates("SELECT _2, SUM(_1) + 2 FROM tbl GROUP BY _2", 2) + checkSparkAnswerAndNumOfAggregates("SELECT _2 + 2, SUM(_1) FROM tbl GROUP BY _2", 2) + checkSparkAnswerAndNumOfAggregates("SELECT _2, SUM(_1 + 1) FROM tbl GROUP BY _2", 2) + + // result expression is unsupported by Comet, so only partial aggregation should be used + val df = sql( + "SELECT _2, MIN(_1) + java_method('java.lang.Math', 'random') " + + "FROM tbl GROUP BY _2") + assert(getNumCometHashAggregate(df) == 1) + } } } } test("test final sum") { - Seq(false, true).foreach { dictionaryEnabled => - val n = 1 - withParquetTable((0L until 5L).map(i => (i, i % 2)), "tbl", dictionaryEnabled) { - checkSparkAnswerAndNumOfAggregates("SELECT _2, SUM(_1), MIN(_1) FROM tbl GROUP BY _2", n) - checkSparkAnswerAndNumOfAggregates("SELECT SUM(_1) FROM tbl", n) - checkSparkAnswerAndNumOfAggregates( - "SELECT _2, MIN(_1), MAX(_1), COUNT(_1), SUM(_1), AVG(_1) FROM tbl GROUP BY _2", - n) - checkSparkAnswerAndNumOfAggregates( - "SELECT MIN(_1), MIN(_2), MAX(_1), MAX(_2), COUNT(_1), COUNT(_2), SUM(_1), SUM(_2) FROM tbl", - n) + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") { + Seq(false, true).foreach { dictionaryEnabled => + withParquetTable((0L until 5L).map(i => (i, i % 2)), "tbl", dictionaryEnabled) { + checkSparkAnswerAndNumOfAggregates( + "SELECT _2, SUM(_1), MIN(_1) FROM tbl GROUP BY _2", + 2) + checkSparkAnswerAndNumOfAggregates("SELECT SUM(_1) FROM tbl", 2) + checkSparkAnswerAndNumOfAggregates( + "SELECT _2, MIN(_1), MAX(_1), COUNT(_1), SUM(_1), AVG(_1) FROM tbl GROUP BY _2", + 2) + checkSparkAnswerAndNumOfAggregates( + "SELECT MIN(_1), MIN(_2), MAX(_1), MAX(_2), COUNT(_1), COUNT(_2), SUM(_1), SUM(_2) FROM tbl", + 2) + } } } } test("test final avg") { - Seq(true, false).foreach { dictionaryEnabled => - withParquetTable( - (0 until 5).map(i => (i.toDouble, i.toDouble % 2)), - "tbl", - dictionaryEnabled) { - val n = 1 - checkSparkAnswerAndNumOfAggregates("SELECT _2 , AVG(_1) FROM tbl GROUP BY _2", n) - checkSparkAnswerAndNumOfAggregates("SELECT AVG(_1) FROM tbl", n) - checkSparkAnswerAndNumOfAggregates( - "SELECT _2, MIN(_1), MAX(_1), COUNT(_1), SUM(_1), AVG(_1) FROM tbl GROUP BY _2", - n) - checkSparkAnswerAndNumOfAggregates( - "SELECT MIN(_1), MIN(_2), MAX(_1), MAX(_2), COUNT(_1), COUNT(_2), SUM(_1), SUM(_2), AVG(_1), AVG(_2) FROM tbl", - n) + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") { + Seq(true, false).foreach { dictionaryEnabled => + withParquetTable( + (0 until 5).map(i => (i.toDouble, i.toDouble % 2)), + "tbl", + dictionaryEnabled) { + checkSparkAnswerAndNumOfAggregates("SELECT _2 , AVG(_1) FROM tbl GROUP BY _2", 2) + checkSparkAnswerAndNumOfAggregates("SELECT AVG(_1) FROM tbl", 2) + checkSparkAnswerAndNumOfAggregates( + "SELECT _2, MIN(_1), MAX(_1), COUNT(_1), SUM(_1), AVG(_1) FROM tbl GROUP BY _2", + 2) + checkSparkAnswerAndNumOfAggregates( + "SELECT MIN(_1), MIN(_2), MAX(_1), MAX(_2), COUNT(_1), COUNT(_2), SUM(_1), SUM(_2), AVG(_1), AVG(_2) FROM tbl", + 2) + } } } } @@ -542,31 +595,34 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { // TODO: enable decimal average for Spark 3.2 & 3.3 assume(isSpark34Plus) - Seq(true, false).foreach { dictionaryEnabled => - withSQLConf("parquet.enable.dictionary" -> dictionaryEnabled.toString) { - val table = "t1" - withTable(table) { - sql(s"create table $table(a decimal(38, 37), b INT) using parquet") - sql(s"insert into $table values(-0.0000000000000000000000000000000000002, 1)") - sql(s"insert into $table values(-0.0000000000000000000000000000000000002, 1)") - sql(s"insert into $table values(-0.0000000000000000000000000000000000004, 2)") - sql(s"insert into $table values(-0.0000000000000000000000000000000000004, 2)") - sql(s"insert into $table values(-0.00000000000000000000000000000000000002, 3)") - sql(s"insert into $table values(-0.00000000000000000000000000000000000002, 3)") - sql(s"insert into $table values(-0.00000000000000000000000000000000000004, 4)") - sql(s"insert into $table values(-0.00000000000000000000000000000000000004, 4)") - sql(s"insert into $table values(0.13344406545919155429936259114971302408, 5)") - sql(s"insert into $table values(0.13344406545919155429936259114971302408, 5)") - - val n = 1 - checkSparkAnswerAndNumOfAggregates("SELECT b , AVG(a) FROM t1 GROUP BY b", n) - checkSparkAnswerAndNumOfAggregates("SELECT AVG(a) FROM t1", n) - checkSparkAnswerAndNumOfAggregates( - "SELECT b, MIN(a), MAX(a), COUNT(a), SUM(a), AVG(a) FROM t1 GROUP BY b", - n) - checkSparkAnswerAndNumOfAggregates( - "SELECT MIN(a), MAX(a), COUNT(a), SUM(a), AVG(a) FROM t1", - n) + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") { + Seq(true, false).foreach { dictionaryEnabled => + withSQLConf("parquet.enable.dictionary" -> dictionaryEnabled.toString) { + val table = "t1" + withTable(table) { + sql(s"create table $table(a decimal(38, 37), b INT) using parquet") + sql(s"insert into $table values(-0.0000000000000000000000000000000000002, 1)") + sql(s"insert into $table values(-0.0000000000000000000000000000000000002, 1)") + sql(s"insert into $table values(-0.0000000000000000000000000000000000004, 2)") + sql(s"insert into $table values(-0.0000000000000000000000000000000000004, 2)") + sql(s"insert into $table values(-0.00000000000000000000000000000000000002, 3)") + sql(s"insert into $table values(-0.00000000000000000000000000000000000002, 3)") + sql(s"insert into $table values(-0.00000000000000000000000000000000000004, 4)") + sql(s"insert into $table values(-0.00000000000000000000000000000000000004, 4)") + sql(s"insert into $table values(0.13344406545919155429936259114971302408, 5)") + sql(s"insert into $table values(0.13344406545919155429936259114971302408, 5)") + + checkSparkAnswerAndNumOfAggregates("SELECT b , AVG(a) FROM t1 GROUP BY b", 2) + checkSparkAnswerAndNumOfAggregates("SELECT AVG(a) FROM t1", 2) + checkSparkAnswerAndNumOfAggregates( + "SELECT b, MIN(a), MAX(a), COUNT(a), SUM(a), AVG(a) FROM t1 GROUP BY b", + 2) + checkSparkAnswerAndNumOfAggregates( + "SELECT MIN(a), MAX(a), COUNT(a), SUM(a), AVG(a) FROM t1", + 2) + } } } } @@ -584,54 +640,62 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("avg null handling") { - val table = "t1" - withTable(table) { - sql(s"create table $table(a double, b double) using parquet") - sql(s"insert into $table values(1, 1.0)") - sql(s"insert into $table values(null, null)") - sql(s"insert into $table values(1, 2.0)") - sql(s"insert into $table values(null, null)") - sql(s"insert into $table values(2, null)") - sql(s"insert into $table values(2, null)") - - val query = sql(s"select a, AVG(b) from $table GROUP BY a") - checkSparkAnswer(query) + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") { + val table = "t1" + withTable(table) { + sql(s"create table $table(a double, b double) using parquet") + sql(s"insert into $table values(1, 1.0)") + sql(s"insert into $table values(null, null)") + sql(s"insert into $table values(1, 2.0)") + sql(s"insert into $table values(null, null)") + sql(s"insert into $table values(2, null)") + sql(s"insert into $table values(2, null)") + + val query = sql(s"select a, AVG(b) from $table GROUP BY a") + checkSparkAnswer(query) + } } } test("Decimal Avg with DF") { Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test") - makeParquetFile(path, 1000, 20, dictionaryEnabled) - withParquetTable(path.toUri.toString, "tbl") { - val expectedNumOfCometAggregates = 1 - - checkSparkAnswerAndNumOfAggregates( - "SELECT _g2, AVG(_7) FROM tbl GROUP BY _g2", - expectedNumOfCometAggregates) - - checkSparkAnswerWithTol("SELECT _g3, AVG(_8) FROM tbl GROUP BY _g3") - assert( - getNumCometHashAggregate( - sql("SELECT _g3, AVG(_8) FROM tbl GROUP BY _g3")) == expectedNumOfCometAggregates) - - checkSparkAnswerAndNumOfAggregates( - "SELECT _g4, AVG(_9) FROM tbl GROUP BY _g4", - expectedNumOfCometAggregates) - - checkSparkAnswerAndNumOfAggregates( - "SELECT AVG(_7) FROM tbl", - expectedNumOfCometAggregates) - - checkSparkAnswerWithTol("SELECT AVG(_8) FROM tbl") - assert( - getNumCometHashAggregate( - sql("SELECT AVG(_8) FROM tbl")) == expectedNumOfCometAggregates) - - checkSparkAnswerAndNumOfAggregates( - "SELECT AVG(_9) FROM tbl", - expectedNumOfCometAggregates) + Seq(true, false).foreach { nativeShuffleEnabled => + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> nativeShuffleEnabled.toString, + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") { + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test") + makeParquetFile(path, 1000, 20, dictionaryEnabled) + withParquetTable(path.toUri.toString, "tbl") { + val expectedNumOfCometAggregates = if (nativeShuffleEnabled) 2 else 1 + + checkSparkAnswerAndNumOfAggregates( + "SELECT _g2, AVG(_7) FROM tbl GROUP BY _g2", + expectedNumOfCometAggregates) + + checkSparkAnswerWithTol("SELECT _g3, AVG(_8) FROM tbl GROUP BY _g3") + assert(getNumCometHashAggregate( + sql("SELECT _g3, AVG(_8) FROM tbl GROUP BY _g3")) == expectedNumOfCometAggregates) + + checkSparkAnswerAndNumOfAggregates( + "SELECT _g4, AVG(_9) FROM tbl GROUP BY _g4", + expectedNumOfCometAggregates) + + checkSparkAnswerAndNumOfAggregates( + "SELECT AVG(_7) FROM tbl", + expectedNumOfCometAggregates) + + checkSparkAnswerWithTol("SELECT AVG(_8) FROM tbl") + assert(getNumCometHashAggregate( + sql("SELECT AVG(_8) FROM tbl")) == expectedNumOfCometAggregates) + + checkSparkAnswerAndNumOfAggregates( + "SELECT AVG(_9) FROM tbl", + expectedNumOfCometAggregates) + } + } } } } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index c8e31ef4c..1334f2f77 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -21,16 +21,20 @@ package org.apache.comet.exec import scala.util.Random +import org.scalactic.source.Position +import org.scalatest.Tag + import org.apache.hadoop.fs.Path import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame, DataFrameWriter, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable} import org.apache.spark.sql.catalyst.expressions.Hex import org.apache.spark.sql.comet.{CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec} -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.aggregate.HashAggregateExec -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} +import org.apache.spark.sql.execution.{CollectLimitExec, UnionExec} +import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec} +import org.apache.spark.sql.functions.{date_add, expr} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE import org.apache.spark.unsafe.types.UTF8String @@ -40,6 +44,15 @@ import org.apache.comet.CometConf class CometExecSuite extends CometTestBase { import testImplicits._ + override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + super.test(testName, testTags: _*) { + withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + testFun + } + } + } + test("scalar subquery") { val dataTypes = Seq( @@ -56,41 +69,45 @@ class CometExecSuite extends CometTestBase { "BINARY", "DECIMAL(38, 10)") dataTypes.map { subqueryType => - withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") { - var column1 = s"CAST(max(_1) AS $subqueryType)" - if (subqueryType == "BINARY") { - // arrow-rs doesn't support casting integer to binary yet. - // We added it to upstream but it's not released yet. - column1 = "CAST(CAST(max(_1) AS STRING) AS BINARY)" - } + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") { + var column1 = s"CAST(max(_1) AS $subqueryType)" + if (subqueryType == "BINARY") { + // arrow-rs doesn't support casting integer to binary yet. + // We added it to upstream but it's not released yet. + column1 = "CAST(CAST(max(_1) AS STRING) AS BINARY)" + } - val df1 = sql(s"SELECT (SELECT $column1 FROM tbl) AS a, _1, _2 FROM tbl") - checkSparkAnswerAndOperator(df1) + val df1 = sql(s"SELECT (SELECT $column1 FROM tbl) AS a, _1, _2 FROM tbl") + checkSparkAnswerAndOperator(df1) - var column2 = s"CAST(_1 AS $subqueryType)" - if (subqueryType == "BINARY") { - // arrow-rs doesn't support casting integer to binary yet. - // We added it to upstream but it's not released yet. - column2 = "CAST(CAST(_1 AS STRING) AS BINARY)" - } + var column2 = s"CAST(_1 AS $subqueryType)" + if (subqueryType == "BINARY") { + // arrow-rs doesn't support casting integer to binary yet. + // We added it to upstream but it's not released yet. + column2 = "CAST(CAST(_1 AS STRING) AS BINARY)" + } - val df2 = sql(s"SELECT _1, _2 FROM tbl WHERE $column2 > (SELECT $column1 FROM tbl)") - checkSparkAnswerAndOperator(df2) + val df2 = sql(s"SELECT _1, _2 FROM tbl WHERE $column2 > (SELECT $column1 FROM tbl)") + checkSparkAnswerAndOperator(df2) - // Non-correlated exists subquery will be rewritten to scalar subquery - val df3 = sql( - "SELECT * FROM tbl WHERE EXISTS " + - s"(SELECT $column2 FROM tbl WHERE _1 > 1)") - checkSparkAnswerAndOperator(df3) + // Non-correlated exists subquery will be rewritten to scalar subquery + val df3 = sql( + "SELECT * FROM tbl WHERE EXISTS " + + s"(SELECT $column2 FROM tbl WHERE _1 > 1)") + checkSparkAnswerAndOperator(df3) - // Null value - column1 = s"CAST(NULL AS $subqueryType)" - if (subqueryType == "BINARY") { - column1 = "CAST(CAST(NULL AS STRING) AS BINARY)" - } + // Null value + column1 = s"CAST(NULL AS $subqueryType)" + if (subqueryType == "BINARY") { + column1 = "CAST(CAST(NULL AS STRING) AS BINARY)" + } - val df4 = sql(s"SELECT (SELECT $column1 FROM tbl LIMIT 1) AS a, _1, _2 FROM tbl") - checkSparkAnswerAndOperator(df4) + val df4 = sql(s"SELECT (SELECT $column1 FROM tbl LIMIT 1) AS a, _1, _2 FROM tbl") + checkSparkAnswerAndOperator(df4) + } } } } @@ -122,6 +139,68 @@ class CometExecSuite extends CometTestBase { } } + test( + "fix: ReusedExchangeExec + CometShuffleExchangeExec under QueryStageExec " + + "should be CometRoot") { + val tableName = "table1" + val dim = "dim" + + withSQLConf( + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + withTable(tableName, dim) { + + sql( + s"CREATE TABLE $tableName (id BIGINT, price FLOAT, date DATE, ts TIMESTAMP) USING parquet " + + "PARTITIONED BY (id)") + sql(s"CREATE TABLE $dim (id BIGINT, date DATE) USING parquet") + + spark + .range(1, 100) + .withColumn("date", date_add(expr("DATE '1970-01-01'"), expr("CAST(id % 4 AS INT)"))) + .withColumn("ts", expr("TO_TIMESTAMP(date)")) + .withColumn("price", expr("CAST(id AS FLOAT)")) + .select("id", "price", "date", "ts") + .coalesce(1) + .write + .mode(SaveMode.Append) + .partitionBy("id") + .saveAsTable(tableName) + + spark + .range(1, 10) + .withColumn("date", expr("DATE '1970-01-02'")) + .select("id", "date") + .coalesce(1) + .write + .mode(SaveMode.Append) + .saveAsTable(dim) + + val query = + s""" + |SELECT $tableName.id, sum(price) as sum_price + |FROM $tableName, $dim + |WHERE $tableName.id = $dim.id AND $tableName.date = $dim.date + |GROUP BY $tableName.id HAVING sum(price) > ( + | SELECT sum(price) * 0.0001 FROM $tableName, $dim WHERE $tableName.id = $dim.id AND $tableName.date = $dim.date + | ) + |ORDER BY sum_price + |""".stripMargin + + val df = sql(query) + checkSparkAnswer(df) + val exchanges = stripAQEPlan(df.queryExecution.executedPlan).collect { + case s: CometShuffleExchangeExec if s.shuffleType == CometColumnarShuffle => + s + s + } + assert(exchanges.length == 4) + } + } + } + test("expand operator") { val data1 = (0 until 1000) .map(_ % 5) // reduce value space to trigger dictionary encoding @@ -131,7 +210,7 @@ class CometExecSuite extends CometTestBase { Seq(data1, data2).foreach { tableData => withParquetTable(tableData, "tbl") { val df = sql("SELECT _1, _2, SUM(_3) FROM tbl GROUP BY _1, _2 GROUPING SETS ((_1), (_2))") - checkSparkAnswerAndOperator(df, classOf[HashAggregateExec], classOf[ShuffleExchangeExec]) + checkSparkAnswerAndOperator(df) } } } @@ -265,12 +344,14 @@ class CometExecSuite extends CometTestBase { } test("final aggregation") { - withParquetTable( - (0 until 100) - .map(_ => (Random.nextInt(), Random.nextInt() % 5)), - "tbl") { - val df = sql("SELECT _2, COUNT(*) FROM tbl GROUP BY _2") - checkSparkAnswerAndOperator(df, classOf[HashAggregateExec], classOf[ShuffleExchangeExec]) + withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + withParquetTable( + (0 until 100) + .map(_ => (Random.nextInt(), Random.nextInt() % 5)), + "tbl") { + val df = sql("SELECT _2, COUNT(*) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(df) + } } } @@ -281,6 +362,18 @@ class CometExecSuite extends CometTestBase { } } + test("global sort (columnar shuffle only)") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") { + val df = sql("SELECT * FROM tbl").sort($"_1".desc) + checkSparkAnswerAndOperator(df) + } + } + } + test("spill sort with (multiple) dictionaries") { withSQLConf(CometConf.COMET_MEMORY_OVERHEAD.key -> "15MB") { withTempDir { dir => @@ -320,6 +413,22 @@ class CometExecSuite extends CometTestBase { } } + test("limit") { + Seq("true", "false").foreach { columnarShuffle => + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> columnarShuffle) { + withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") { + val df = sql("SELECT * FROM tbl_a") + .repartition(10, $"_1") + .limit(2) + .sort($"_2".desc) + checkSparkAnswerAndOperator(df) + } + } + } + } + test("limit (cartesian product)") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") { @@ -384,7 +493,7 @@ class CometExecSuite extends CometTestBase { .write .saveAsTable("t1") val query = sql("SELECT count(1) FROM t1") - checkSparkAnswerAndOperator(query, classOf[HashAggregateExec], classOf[ShuffleExchangeExec]) + checkSparkAnswerAndOperator(query) } } @@ -526,10 +635,7 @@ class CometExecSuite extends CometTestBase { "parquet.enable.dictionary" -> dictionary) { withParquetTable(Seq((Long.MaxValue, 1), (Long.MaxValue, 2)), "tbl") { val df = sql("SELECT sum(_1) FROM tbl") - checkSparkAnswerAndOperator( - df, - classOf[HashAggregateExec], - classOf[ShuffleExchangeExec]) + checkSparkAnswerAndOperator(df) } } } @@ -635,11 +741,19 @@ class CometExecSuite extends CometTestBase { val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2") - val BucketedTableTestSpec(bucketSpecLeft, numPartitionsLeft, _, _, _) = - bucketedTableTestSpecLeft + val BucketedTableTestSpec( + bucketSpecLeft, + numPartitionsLeft, + shuffleLeft, + sortLeft, + numOutputPartitionsLeft) = bucketedTableTestSpecLeft - val BucketedTableTestSpec(bucketSpecRight, numPartitionsRight, _, _, _) = - bucketedTableTestSpecRight + val BucketedTableTestSpec( + bucketSpecRight, + numPartitionsRight, + shuffleRight, + sortRight, + numOutputPartitionsRight) = bucketedTableTestSpecRight withTable("bucketed_table1", "bucketed_table2") { withBucket(df1.repartition(numPartitionsLeft).write.format("parquet"), bucketSpecLeft) @@ -749,12 +863,7 @@ class CometExecSuite extends CometTestBase { assert(rdd.partitions.length == 10) val coalesced = df.coalesce(2).select($"l" + 1).sortWithinPartitions($"l") - checkSparkAnswerAndOperator( - coalesced, - classOf[ProjectExec], - classOf[SortExec], - classOf[CoalesceExec], - classOf[ShuffleExchangeExec]) + checkSparkAnswerAndOperator(coalesced) } } @@ -773,25 +882,27 @@ class CometExecSuite extends CometTestBase { } test("coalesce") { - withTable("t1") { - (0 until 5) - .map(i => (i, (i + 1).toLong)) - .toDF("l", "b") - .write - .saveAsTable("t1") + withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + withTable("t1") { + (0 until 5) + .map(i => (i, (i + 1).toLong)) + .toDF("l", "b") + .write + .saveAsTable("t1") - val df = sql("SELECT * FROM t1") - .sortWithinPartitions($"l".desc) - .repartition(10, $"l") + val df = sql("SELECT * FROM t1") + .sortWithinPartitions($"l".desc) + .repartition(10, $"l") - val rdd = df.rdd - assert(rdd.getNumPartitions == 10) + val rdd = df.rdd + assert(rdd.getNumPartitions == 10) - val coalesced = df.coalesce(2) - checkSparkAnswerAndOperator(coalesced, classOf[CoalesceExec], classOf[ShuffleExchangeExec]) + val coalesced = df.coalesce(2) + checkSparkAnswerAndOperator(coalesced) - val coalescedRdd = coalesced.rdd - assert(coalescedRdd.getNumPartitions == 2) + val coalescedRdd = coalesced.rdd + assert(coalescedRdd.getNumPartitions == 2) + } } } } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala new file mode 100644 index 000000000..db78bc27f --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala @@ -0,0 +1,843 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.exec + +import org.scalactic.source.Position +import org.scalatest.Tag + +import org.apache.hadoop.fs.Path +import org.apache.spark.{Partitioner, SparkConf} +import org.apache.spark.sql.{CometTestBase, Row} +import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +import org.apache.comet.CometConf +import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus + +abstract class CometShuffleSuiteBase extends CometTestBase with AdaptiveSparkPlanHelper { + + protected val adaptiveExecutionEnabled: Boolean + + protected val fastMergeEnabled: Boolean = true + + protected val numElementsForceSpillThreshold: Int = 10 + + protected val encryptionEnabled: Boolean = false + + override protected def sparkConf: SparkConf = { + val conf = super.sparkConf + conf + .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, adaptiveExecutionEnabled.toString) + .set("spark.shuffle.unsafe.fastMergeEnabled", fastMergeEnabled.toString) + .set("spark.io.encryption.enabled", encryptionEnabled.toString) + } + + protected val asyncShuffleEnable: Boolean + + override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + super.test(testName, testTags: _*) { + withSQLConf( + CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> asyncShuffleEnable.toString, + CometConf.COMET_EXEC_SHUFFLE_SPILL_THRESHOLD.key -> numElementsForceSpillThreshold.toString) { + testFun + } + } + } + + import testImplicits._ + + test("columnar shuffle on array") { + Seq(10, 201).foreach { numPartitions => + Seq("1.0", "10.0").foreach { ratio => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "false", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) { + withParquetTable( + (0 until 50).map(i => + ( + Seq(i + 1, i + 2, i + 3), + Seq(i.toLong, (i + 2).toLong, (i + 5).toLong), + Seq(i.toString, (i + 3).toString, (i + 2).toString), + Seq( + ( + i + 1, + Seq(i + 3, i + 1, i + 2), // nested array in struct + Seq(i.toLong, (i + 2).toLong, (i + 5).toLong), + Seq(i.toString, (i + 3).toString, (i + 2).toString), + (i + 2).toString)), + i + 1)), + "tbl") { + val df = sql("SELECT * FROM tbl") + .filter($"_5" > 10) + .repartition(numPartitions, $"_1", $"_2", $"_3", $"_4", $"_5") + .sortWithinPartitions($"_1") + + checkSparkAnswer(df) + checkCometExchange(df, 1, false) + } + } + } + } + } + + test("columnar shuffle on nested array") { + Seq("false", "true").foreach { execEnabled => + Seq(10, 201).foreach { numPartitions => + Seq("1.0", "10.0").foreach { ratio => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> execEnabled, + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) { + withParquetTable( + (0 until 50).map(i => (Seq(Seq(i + 1), Seq(i + 2), Seq(i + 3)), i + 1)), + "tbl") { + val df = sql("SELECT * FROM tbl") + .filter($"_2" > 10) + .repartition(numPartitions, $"_1", $"_2") + .sortWithinPartitions($"_1") + + checkSparkAnswer(df) + // Nested array fallback to Spark shuffle for now + checkCometExchange(df, 0, false) + } + } + } + } + } + } + + test("columnar shuffle on nested struct") { + Seq(10, 201).foreach { numPartitions => + Seq("1.0", "10.0").foreach { ratio => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "false", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) { + withParquetTable( + (0 until 50).map(i => + ((i, 2.toString, (i + 1).toLong, (3.toString, i + 1, (i + 2).toLong)), i + 1)), + "tbl") { + val df = sql("SELECT * FROM tbl") + .filter($"_2" > 10) + .repartition(numPartitions, $"_1", $"_2") + .sortWithinPartitions($"_1") + + checkSparkAnswer(df) + checkCometExchange(df, 1, false) + } + } + } + } + } + + test("fix: Dictionary arrays imported from native should not be overridden") { + Seq(10, 201).foreach { numPartitions => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_BATCH_SIZE.key -> "10", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 50).map(i => (1.toString, 2.toString, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl") + .filter($"_1" === 1.toString) + .repartition(numPartitions, $"_1", $"_2") + .sortWithinPartitions($"_1") + checkSparkAnswerAndOperator(df) + } + } + } + } + + test("fix: closing sliced dictionary Comet vector should not close dictionary array") { + (0 to 10).foreach { _ => + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + CometConf.COMET_BATCH_SIZE.key -> "10", + CometConf.COMET_EXEC_ENABLED.key -> "false", + CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> "1.1", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_SPILL_THRESHOLD.key -> "1000000000", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + val table1 = (0 until 1000) + .map(i => (111111.toString, 2222222.toString, 3333333.toString, i.toLong)) + .toDF("a", "b", "c", "d") + val table2 = (0 until 1000) + .map(i => (3333333.toString, 2222222.toString, 111111.toString, i.toLong)) + .toDF("e", "f", "g", "h") + withParquetTable(table1, "tbl_a") { + withParquetTable(table2, "tbl_b") { + val df = sql( + "select a, b, count(distinct h) from tbl_a, tbl_b " + + "where c = e and b = '2222222' and a not like '2' group by a, b") + checkSparkAnswer(df) + } + } + } + } + } + + test("fix: Dictionary field should have distinct dict_id") { + Seq(10, 201).foreach { numPartitions => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "false", + CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> "2.0", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + withParquetTable( + (0 until 10000).map(i => (1.toString, 2.toString, (i + 1).toLong)), + "tbl") { + assert( + sql("SELECT * FROM tbl").repartition(numPartitions, $"_1", $"_2").count() == sql( + "SELECT * FROM tbl") + .count()) + val shuffled = sql("SELECT * FROM tbl").repartition(numPartitions, $"_1") + checkSparkAnswer(shuffled) + } + } + } + } + + test("dictionary shuffle") { + Seq(10, 201).foreach { numPartitions => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "false", + CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> "2.0", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 10000).map(i => (1.toString, (i + 1).toLong)), "tbl") { + assert( + sql("SELECT * FROM tbl").repartition(numPartitions, $"_1").count() == sql( + "SELECT * FROM tbl") + .count()) + val shuffled = sql("SELECT * FROM tbl").select($"_1").repartition(numPartitions, $"_1") + checkSparkAnswer(shuffled) + } + } + } + } + + test("dictionary shuffle: fallback to string") { + Seq(10, 201).foreach { numPartitions => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "false", + CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> "1000000000.0", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 10000).map(i => (1.toString, (i + 1).toLong)), "tbl") { + assert( + sql("SELECT * FROM tbl").repartition(numPartitions, $"_1").count() == sql( + "SELECT * FROM tbl") + .count()) + val shuffled = sql("SELECT * FROM tbl").select($"_1").repartition(numPartitions, $"_1") + checkSparkAnswer(shuffled) + } + } + } + } + + test("fix: inMemSorter should be reset after spilling") { + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "false", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 10000).map(i => (1, (i + 1).toLong)), "tbl") { + assert( + sql("SELECT * FROM tbl").repartition(201, $"_1").count() == sql("SELECT * FROM tbl") + .count()) + } + } + } + + test("fix: native Unsafe row accessors return incorrect results") { + Seq(10, 201).foreach { numPartitions => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, false, 10000, 10010) + + Seq( + $"_1", + $"_2", + $"_3", + $"_4", + $"_5", + $"_6", + $"_7", + $"_8", + $"_13", + $"_14", + $"_15", + $"_16", + $"_17", + $"_18", + $"_19", + $"_20").foreach { col => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "false", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + readParquetFile(path.toString) { df => + val shuffled = df.select(col).repartition(numPartitions, col) + checkSparkAnswer(shuffled) + } + } + } + } + } + } + + test("fix: StreamReader should always set useDecimal128 as true") { + Seq(10, 201).foreach { numPartitions => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + withTempPath { dir => + val data = makeDecimalRDD(1000, DecimalType(12, 2), false) + data.write.parquet(dir.getCanonicalPath) + readParquetFile(dir.getCanonicalPath) { df => + { + val shuffled = df.repartition(numPartitions, $"dec") + checkSparkAnswer(shuffled) + } + } + } + } + } + } + + test("fix: Native Unsafe decimal accessors return incorrect results") { + Seq(10, 201).foreach { numPartitions => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + withTempPath { dir => + val data = makeDecimalRDD(1000, DecimalType(22, 2), false) + data.write.parquet(dir.getCanonicalPath) + readParquetFile(dir.getCanonicalPath) { df => + { + val shuffled = df.repartition(numPartitions, $"dec") + checkSparkAnswer(shuffled) + } + } + } + } + } + } + + test("Comet shuffle reader should respect spark.comet.batchSize") { + Seq(10, 201).foreach { numPartitions => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 10000).map(i => (1, (i + 1).toLong)), "tbl") { + assert( + sql("SELECT * FROM tbl").repartition(numPartitions, $"_1").count() == sql( + "SELECT * FROM tbl").count()) + } + } + } + } + + test("Arrow shuffle should work with BatchScan") { + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "", // Use DataSourceV2 + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", // Disable AQE + CometConf.COMET_SCAN_ENABLED.key -> "false", // Disable CometScan to use Spark BatchScan + CometConf.COMET_EXEC_ENABLED.key -> "false", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl") + val shuffled = df + .select($"_1" + 1 as ("a")) + .filter($"a" > 4) + .repartitionByRange(10, $"_2") + .limit(2) + .repartition(10, $"_1") + + checkAnswer(shuffled, Row(5) :: Nil) + } + } + } + + test("Columnar shuffle for large shuffle partition number") { + Seq(10, 200, 201).foreach { numPartitions => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "false", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl") + + val shuffled = df.repartitionByRange(numPartitions, $"_2") + + val cometShuffleExecs = checkCometExchange(shuffled, 1, false) + // `CometSerializedShuffleHandle` is used for large shuffle partition number, + // i.e., sort-based shuffle writer + cometShuffleExecs(0).shuffleDependency.shuffleHandle.getClass.getName + .contains("CometSerializedShuffleHandle") + + checkSparkAnswer(shuffled) + } + } + } + } + + test("grouped aggregate: Comet shuffle") { + withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") { + val df = sql("SELECT count(_2), sum(_2) FROM tbl GROUP BY _1") + checkCometExchange(df, 1, true) + checkSparkAnswerAndOperator(df) + } + } + } + + test("hash shuffle: Comet shuffle") { + // Disable CometExec to explicit test Comet Arrow shuffle path + Seq(true, false).foreach { execEnabled => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> execEnabled.toString, + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> (!execEnabled).toString) { + withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc) + val shuffled1 = df.repartition(10, $"_1") + + // If Comet execution is disabled, `Sort` operator is Spark operator + // and jvm arrow shuffle is applied. + checkCometExchange(shuffled1, 1, execEnabled) + checkSparkAnswer(shuffled1) + + val shuffled2 = df.repartition(10, $"_1", $"_2") + + checkCometExchange(shuffled2, 1, execEnabled) + checkSparkAnswer(shuffled2) + + val shuffled3 = df.repartition(10, $"_2", $"_1") + + checkCometExchange(shuffled3, 1, execEnabled) + checkSparkAnswer(shuffled3) + } + } + } + } + + test("Comet shuffle: different data type") { + // Disable CometExec to explicit test Comet native shuffle path + Seq(true, false).foreach { execEnabled => + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) + val all_types = if (isSpark34Plus) { + Seq( + $"_1", + $"_2", + $"_3", + $"_4", + $"_5", + $"_6", + $"_7", + $"_8", + $"_9", + $"_10", + $"_11", + $"_13", + $"_14", + $"_15", + $"_16", + $"_17", + $"_18", + $"_19", + $"_20") + } else { + Seq( + $"_1", + $"_2", + $"_3", + $"_4", + $"_5", + $"_6", + $"_7", + $"_8", + $"_9", + $"_10", + $"_11", + $"_13", + $"_15", + $"_16", + $"_18", + $"_19", + $"_20") + } + all_types.foreach { col => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> execEnabled.toString, + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + "parquet.enable.dictionary" -> dictionaryEnabled.toString) { + readParquetFile(path.toString) { df => + val shuffled = df + .select($"_1") + .repartition(10, col) + checkCometExchange(shuffled, 1, true) + if (execEnabled) { + checkSparkAnswerAndOperator(shuffled) + } else { + checkSparkAnswer(shuffled) + } + } + } + } + } + } + } + } + + test("hash shuffle: Comet columnar shuffle") { + Seq(10, 200, 201).foreach { numPartitions => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "false", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl") + + val shuffled1 = + df.repartitionByRange(numPartitions, $"_2").limit(2).repartition(numPartitions, $"_1") + + // 3 exchanges are expected: 1) shuffle to repartition by range, 2) shuffle to global limit, 3) hash shuffle + checkCometExchange(shuffled1, 3, false) + checkSparkAnswer(shuffled1) + + val shuffled2 = df + .repartitionByRange(numPartitions, $"_2") + .limit(2) + .repartition(numPartitions, $"_1", $"_2") + + checkCometExchange(shuffled2, 3, false) + checkSparkAnswer(shuffled2) + + val shuffled3 = df + .repartitionByRange(numPartitions, $"_2") + .limit(2) + .repartition(numPartitions, $"_2", $"_1") + + checkCometExchange(shuffled3, 3, false) + checkSparkAnswer(shuffled3) + } + } + } + } + + test("Comet columnar shuffle shuffle: different data type") { + Seq(10, 200, 201).foreach { numPartitions => + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) + + Seq( + $"_1", + $"_2", + $"_3", + $"_4", + $"_5", + $"_6", + $"_7", + $"_8", + $"_9", + $"_10", + $"_11", + $"_13", + $"_14", + $"_15", + $"_16", + $"_17", + $"_18", + $"_19", + $"_20").foreach { col => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "false", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + readParquetFile(path.toString) { df => + val shuffled = df + .select($"_1") + .repartition(numPartitions, col) + val cometShuffleExecs = checkCometExchange(shuffled, 1, false) + if (numPartitions > 200) { + // For sort-based shuffle writer + cometShuffleExecs(0).shuffleDependency.shuffleHandle.getClass.getName + .contains("CometSerializedShuffleHandle") + } + checkSparkAnswer(shuffled) + } + } + } + } + } + } + } + + test("Comet native operator after Comet shuffle") { + Seq(true, false).foreach { columnarShuffle => + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> columnarShuffle.toString) { + withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl") + + val shuffled1 = df + .repartition(10, $"_2") + .select($"_1", $"_1" + 1, $"_2" + 2) + .repartition(10, $"_1") + .filter($"_1" > 1) + + // 2 Comet shuffle exchanges are expected + checkCometExchange(shuffled1, 2, !columnarShuffle) + checkSparkAnswer(shuffled1) + + val shuffled2 = df + .repartitionByRange(10, $"_2") + .select($"_1", $"_1" + 1, $"_2" + 2) + .repartition(10, $"_1") + .filter($"_1" > 1) + + // 2 Comet shuffle exchanges are expected, if columnar shuffle is enabled + if (columnarShuffle) { + checkCometExchange(shuffled2, 2, !columnarShuffle) + } else { + // Because the first exchange from the bottom is range exchange which native shuffle + // doesn't support. So Comet exec operators stop before the first exchange and thus + // there is no Comet exchange. + checkCometExchange(shuffled2, 0, true) + } + checkSparkAnswer(shuffled2) + } + } + } + } + + test("Comet shuffle: single partition") { + Seq(true, false).foreach { execEnabled => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> execEnabled.toString, + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> (!execEnabled).toString) { + withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc) + + val shuffled = df.repartition(1) + + checkCometExchange(shuffled, 1, execEnabled) + checkSparkAnswer(shuffled) + } + } + } + } + + test("Comet shuffle metrics") { + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") { + withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc) + val shuffled = df.repartition(10, $"_1") + + checkCometExchange(shuffled, 1, true) + checkSparkAnswer(shuffled) + + // Materialize the shuffled data + shuffled.collect() + val metrics = find(shuffled.queryExecution.executedPlan) { + case _: CometShuffleExchangeExec => true + case _ => false + }.map(_.metrics).get + + assert(metrics.contains("shuffleRecordsWritten")) + assert(metrics("shuffleRecordsWritten").value == 5L) + } + } + } + + test("sort-based shuffle metrics") { + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "false", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc) + val shuffled = df.repartition(201, $"_1") + + checkCometExchange(shuffled, 1, false) + checkSparkAnswer(shuffled) + + // Materialize the shuffled data + shuffled.collect() + val metrics = find(shuffled.queryExecution.executedPlan) { + case _: CometShuffleExchangeExec => true + case _ => false + }.map(_.metrics).get + + assert(metrics.contains("shuffleRecordsWritten")) + assert(metrics("shuffleRecordsWritten").value == 5L) + + assert(metrics.contains("shuffleBytesWritten")) + assert(metrics("shuffleBytesWritten").value > 0) + + assert(metrics.contains("shuffleWriteTime")) + assert(metrics("shuffleWriteTime").value > 0) + } + } + } +} + +class CometAsyncShuffleSuite extends CometShuffleSuiteBase { + override protected val asyncShuffleEnable: Boolean = true + + protected val adaptiveExecutionEnabled: Boolean = true +} + +class CometAsyncNonFastMergeShuffleSuite extends CometShuffleSuiteBase { + override protected val fastMergeEnabled: Boolean = false + + protected val adaptiveExecutionEnabled: Boolean = true + + protected val asyncShuffleEnable: Boolean = true +} + +class CometNonFastMergeShuffleSuite extends CometShuffleSuiteBase { + override protected val fastMergeEnabled: Boolean = false + + protected val adaptiveExecutionEnabled: Boolean = true + + protected val asyncShuffleEnable: Boolean = false +} + +class CometShuffleSuite extends CometShuffleSuiteBase { + override protected val asyncShuffleEnable: Boolean = false + + protected val adaptiveExecutionEnabled: Boolean = true + + import testImplicits._ + + // TODO: this test takes ~5mins to run, we should reduce the test time. + // Because this test takes too long, we only have it in `CometShuffleSuite`. + test("fix: Too many task completion listener of ArrowReaderIterator causes OOM") { + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_BATCH_SIZE.key -> "1", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 1000000).map(i => (1, (i + 1).toLong)), "tbl") { + assert( + sql("SELECT * FROM tbl").repartition(201, $"_1").count() == sql("SELECT * FROM tbl") + .count()) + } + } + } +} + +class DisableAQECometShuffleSuite extends CometShuffleSuiteBase { + override protected val asyncShuffleEnable: Boolean = false + + protected val adaptiveExecutionEnabled: Boolean = false +} + +class DisableAQECometAsyncShuffleSuite extends CometShuffleSuiteBase { + override protected val asyncShuffleEnable: Boolean = true + + protected val adaptiveExecutionEnabled: Boolean = false +} + +/** + * This suite tests the Comet shuffle encryption. Because the encryption configuration can only be + * set in SparkConf at the beginning, we need to create a separate suite for encryption. + */ +class CometShuffleEncryptionSuite extends CometShuffleSuiteBase { + override protected val adaptiveExecutionEnabled: Boolean = true + + override protected val asyncShuffleEnable: Boolean = false + + override protected val encryptionEnabled: Boolean = true +} + +class CometAsyncShuffleEncryptionSuite extends CometShuffleSuiteBase { + override protected val adaptiveExecutionEnabled: Boolean = true + + override protected val asyncShuffleEnable: Boolean = true + + override protected val encryptionEnabled: Boolean = true +} + +class DisableAQECometShuffleEncryptionSuite extends CometShuffleSuiteBase { + override protected val adaptiveExecutionEnabled: Boolean = false + + override protected val asyncShuffleEnable: Boolean = false + + override protected val encryptionEnabled: Boolean = true +} + +class DisableAQECometAsyncShuffleEncryptionSuite extends CometShuffleSuiteBase { + override protected val adaptiveExecutionEnabled: Boolean = false + + override protected val asyncShuffleEnable: Boolean = true + + override protected val encryptionEnabled: Boolean = true +} + +class CometShuffleManagerSuite extends CometTestBase { + + test("should not bypass merge sort if executor cores are too high") { + withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ASYNC_MAX_THREAD_NUM.key -> "100") { + val conf = new SparkConf() + conf.set("spark.executor.cores", "1") + + val rdd = spark.emptyDataFrame.rdd.map(x => (0, x)) + + val dependency = new CometShuffleDependency[Int, Row, Row]( + _rdd = rdd, + serializer = null, + shuffleWriterProcessor = null, + partitioner = new Partitioner { + override def numPartitions: Int = 50 + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + }) + + assert(CometShuffleManager.shouldBypassMergeSort(conf, dependency)) + + conf.set("spark.executor.cores", "10") + assert(!CometShuffleManager.shouldBypassMergeSort(conf, dependency)) + } + } +} diff --git a/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala b/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala index a95674118..2263b2f99 100644 --- a/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala @@ -55,7 +55,7 @@ class CometPluginsDefaultSuite extends CometTestBase { conf.set("spark.executor.memoryOverheadFactor", "0.5") conf.set("spark.plugins", "org.apache.spark.CometPlugin") conf.set("spark.comet.enabled", "true") - conf.set("spark.comet.exec.enabled", "true") + conf.set("spark.comet.exec.shuffle.enabled", "true") conf } diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala index f8213734b..92b4aa7d1 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala @@ -144,11 +144,15 @@ class CometTPCDSQuerySuite override def sparkConf: SparkConf = { val conf = super.sparkConf conf.set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions") + conf.set( + "spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") conf.set(CometConf.COMET_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ENABLED.key, "true") conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g") conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true") + conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true") conf } diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala index 5f1ba6296..372a4ccf7 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala @@ -82,11 +82,15 @@ class CometTPCHQuerySuite extends QueryTest with CometTPCBase with SQLQueryTestH val conf = super.sparkConf conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1") conf.set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions") + conf.set( + "spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") conf.set(CometConf.COMET_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ENABLED.key, "true") conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g") conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true") + conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true") } protected override def createSparkSession: TestSparkSession = { diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 08c6cf419..4f2838cfb 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -36,6 +36,7 @@ import org.apache.parquet.hadoop.example.ExampleParquetWriter import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark._ import org.apache.spark.sql.comet.{CometBatchScanExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder} +import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.internal._ @@ -57,10 +58,14 @@ abstract class CometTestBase with AdaptiveSparkPlanHelper { import testImplicits._ + protected val shuffleManager: String = + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager" + protected def sparkConf: SparkConf = { val conf = new SparkConf() conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) conf.set(SQLConf.SHUFFLE_PARTITIONS, 10) // reduce parallelism in tests + conf.set("spark.shuffle.manager", shuffleManager) conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g") conf } @@ -73,6 +78,7 @@ abstract class CometTestBase CometConf.COMET_EXEC_ENABLED.key -> "true", CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_MEMORY_SIZE.key -> "2g", SQLConf.ANSI_ENABLED.key -> "false") { testFun } @@ -150,7 +156,7 @@ abstract class CometTestBase plan.foreach { case _: CometScanExec | _: CometBatchScanExec => true case _: CometSinkPlaceHolder | _: CometScanWrapper => false - case _: CometExec => true + case _: CometExec | _: CometShuffleExchangeExec => true case _: WholeStageCodegenExec | _: ColumnarToRowExec | _: InputAdapter => true case op => if (excludedClasses.exists(c => c.isAssignableFrom(op.getClass))) { @@ -594,4 +600,33 @@ abstract class CometTestBase def stripRandomPlanParts(plan: String): String = { plan.replaceFirst("file:.*,", "").replaceAll(raw"#\d+", "") } + + protected def checkCometExchange( + df: DataFrame, + cometExchangeNum: Int, + native: Boolean): Seq[CometShuffleExchangeExec] = { + if (CometConf.COMET_EXEC_SHUFFLE_ENABLED.get()) { + val sparkPlan = stripAQEPlan(df.queryExecution.executedPlan) + + val cometShuffleExecs = sparkPlan.collect { case b: CometShuffleExchangeExec => b } + assert( + cometShuffleExecs.length == cometExchangeNum, + s"$sparkPlan has ${cometShuffleExecs.length} " + + s" CometShuffleExchangeExec node which doesn't match the expected: $cometExchangeNum") + + if (native) { + cometShuffleExecs.foreach { b => + assert(b.shuffleType == CometNativeShuffle) + } + } else { + cometShuffleExecs.foreach { b => + assert(b.shuffleType == CometColumnarShuffle) + } + } + + cometShuffleExecs + } else { + Seq.empty + } + } } diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala index 8f3f2438d..d6020ac69 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala @@ -40,6 +40,11 @@ object CometExecBenchmark extends CometBenchmarkBase { .setIfMissing("spark.driver.memory", "3g") .setIfMissing("spark.executor.memory", "3g") .set("spark.executor.memoryOverhead", "10g") + .set( + "spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") + .set("spark.comet.columnar.shuffle.async.thread.num", "7") + .set("spark.comet.columnar.shuffle.spill.threshold", "30000") val sparkSession = SparkSession.builder .config(conf) @@ -52,6 +57,7 @@ object CometExecBenchmark extends CometBenchmarkBase { sparkSession.conf.set(CometConf.COMET_ENABLED.key, "false") sparkSession.conf.set(CometConf.COMET_EXEC_ENABLED.key, "false") sparkSession.conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "10g") + sparkSession.conf.set(CometConf.COMET_COLUMNAR_SHUFFLE_MEMORY_SIZE.key, "10g") // TODO: support dictionary encoding in vectorized execution sparkSession.conf.set("parquet.enable.dictionary", "false") sparkSession.conf.set("spark.sql.shuffle.partitions", "2") @@ -123,7 +129,9 @@ object CometExecBenchmark extends CometBenchmarkBase { withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true", - CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true") { + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { spark.sql( "SELECT (SELECT max(col1) AS parquetV1Table FROM parquetV1Table) AS a, " + "col2, col3 FROM parquetV1Table") diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometShuffleBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometShuffleBenchmark.scala new file mode 100644 index 000000000..865572811 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometShuffleBenchmark.scala @@ -0,0 +1,609 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.benchmark + +import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.{Column, SparkSession} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +import org.apache.comet.CometConf +import org.apache.comet.CometSparkSessionExtensions + +/** + * Benchmark to measure Comet shuffle performance. To run this benchmark: + * `SPARK_GENERATE_BENCHMARK_FILES=1 make + * benchmark-org.apache.spark.sql.benchmark.CometShuffleBenchmark` Results will be written to + * "spark/benchmarks/CometShuffleBenchmark-**results.txt". + */ +object CometShuffleBenchmark extends CometBenchmarkBase { + override def getSparkSession: SparkSession = { + val conf = new SparkConf() + .setAppName("CometShuffleBenchmark") + // Since `spark.master` always exists, overrides this value + .set("spark.master", "local[5]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + .set("spark.executor.memoryOverhead", "10g") + .set( + "spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") + .set("spark.comet.columnar.shuffle.async.thread.num", "7") + .set("spark.comet.columnar.shuffle.spill.threshold", "30000") + .set("spark.comet.memoryOverhead", "10g") + + val sparkSession = SparkSession.builder + .config(conf) + .withExtensions(new CometSparkSessionExtensions) + .getOrCreate() + + // Set default configs. Individual cases will change them if necessary. + sparkSession.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + sparkSession.conf.set(CometConf.COMET_ENABLED.key, "false") + sparkSession.conf.set(CometConf.COMET_EXEC_ENABLED.key, "false") + sparkSession.conf.set(CometConf.COMET_COLUMNAR_SHUFFLE_MEMORY_SIZE.key, "10g") + // TODO: support dictionary encoding in vectorized execution + sparkSession.conf.set("parquet.enable.dictionary", "false") + + sparkSession + } + + def shuffleArrayBenchmark(values: Int, dataType: DataType, partitionNum: Int): Unit = { + val benchmark = + new Benchmark( + s"SQL ${dataType.sql} shuffle on array ($partitionNum Partition)", + values, + output = output) + + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareTable(dir, spark.sql(s"SELECT CAST(1 AS ${dataType.sql}) AS c1 FROM $tbl")) + + benchmark.addCase("SQL Parquet - Spark") { _ => + spark + .sql(s"SELECT ARRAY_REPEAT(CAST(1 AS ${dataType.sql}), 10) AS c1 FROM parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + + benchmark.addCase("SQL Parquet - Comet (Spark Shuffle)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "false") { + spark + .sql( + s"SELECT ARRAY_REPEAT(CAST(1 AS ${dataType.sql}), 10) AS c1 FROM parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + } + + benchmark.addCase("SQL Parquet - Comet (Comet Arrow Shuffle)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> "1.0", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> "false") { + spark + .sql( + s"SELECT ARRAY_REPEAT(CAST(1 AS ${dataType.sql}), 10) AS c1 FROM parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + } + + benchmark.run() + } + } + } + + def shuffleStructBenchmark(values: Int, dataType: DataType, partitionNum: Int): Unit = { + val benchmark = + new Benchmark( + s"SQL ${dataType.sql} shuffle on struct ($partitionNum Partition)", + values, + output = output) + + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareTable(dir, spark.sql(s"SELECT CAST(1 AS ${dataType.sql}) AS c1 FROM $tbl")) + + benchmark.addCase("SQL Parquet - Spark") { _ => + spark + .sql( + s"SELECT STRUCT(CAST(c1 AS ${dataType.sql})," + + s"CAST(c1 AS ${dataType.sql}), " + + s"CAST(c1 AS ${dataType.sql})) AS c1 FROM parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + + benchmark.addCase("SQL Parquet - Comet (Spark Shuffle)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "false") { + spark + .sql( + s"SELECT STRUCT(CAST(c1 AS ${dataType.sql})," + + s"CAST(c1 AS ${dataType.sql}), " + + s"CAST(c1 AS ${dataType.sql})) AS c1 FROM parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + } + + benchmark.addCase("SQL Parquet - Comet (Comet Arrow Shuffle)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> "1.0", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> "false") { + spark + .sql( + s"SELECT STRUCT(CAST(c1 AS ${dataType.sql})," + + s"CAST(c1 AS ${dataType.sql}), " + + s"CAST(c1 AS ${dataType.sql})) AS c1 FROM parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + } + + benchmark.run() + } + } + } + + def shuffleDictionaryBenchmark(values: Int, dataType: DataType, partitionNum: Int): Unit = { + val benchmark = + new Benchmark( + s"SQL ${dataType.sql} Dictionary Shuffle($partitionNum Partition)", + values, + output = output) + + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareTable( + dir, + spark.sql(s"SELECT REPEAT(CAST(1 AS ${dataType.sql}), 100) AS c1 FROM $tbl")) + + benchmark.addCase("SQL Parquet - Spark") { _ => + spark + .sql("select c1 from parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + + benchmark.addCase("SQL Parquet - Comet (Spark Shuffle)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "false") { + spark + .sql("select c1 from parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + } + + benchmark.addCase("SQL Parquet - Comet (Comet Arrow Shuffle)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> "1.0", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> "false") { + spark + .sql("select c1 from parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + } + + benchmark.addCase("SQL Parquet - Comet (Comet Arrow Shuffle + Prefer Dictionary)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> "2.0", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> "false") { + spark + .sql("select c1 from parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + } + + benchmark.addCase("SQL Parquet - Comet (Comet Arrow Shuffle + Fallback to string)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> "1000000000.0", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> "false") { + spark + .sql("select c1 from parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + } + + benchmark.run() + } + } + } + + def shuffleBenchmark( + values: Int, + dataType: DataType, + random: Boolean, + partitionNum: Int): Unit = { + val randomTitle = if (random) { + "With Random" + } else { + "" + } + val benchmark = + new Benchmark( + s"SQL Single ${dataType.sql} Shuffle($partitionNum Partition) $randomTitle", + values, + output = output) + + withTempPath { dir => + withTempTable("parquetV1Table") { + if (random) { + prepareTable( + dir, + spark.sql( + s"SELECT CAST(CAST(RAND(1) * 100 AS INTEGER) AS ${dataType.sql}) AS c1 FROM $tbl")) + } else { + prepareTable(dir, spark.sql(s"SELECT CAST(1 AS ${dataType.sql}) AS c1 FROM $tbl")) + } + + benchmark.addCase("SQL Parquet - Spark") { _ => + spark + .sql("select c1 from parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + + benchmark.addCase("SQL Parquet - Comet (Spark Shuffle)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "false") { + spark + .sql("select c1 from parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + } + + benchmark.addCase("SQL Parquet - Comet (Comet Arrow Shuffle)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> "false") { + spark + .sql("select c1 from parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + } + + benchmark.addCase("SQL Parquet - Comet (Comet Async Arrow Shuffle)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> "true") { + spark + .sql("select c1 from parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + } + + benchmark.addCase("SQL Parquet - Comet (Comet Shuffle)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + spark + .sql("select c1 from parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + } + + benchmark.run() + } + } + } + + def shuffleWideBenchmark( + values: Int, + dataType: DataType, + width: Int, + partitionNum: Int): Unit = { + val benchmark = + new Benchmark( + s"SQL Wide ($width cols) ${dataType.sql} Shuffle($partitionNum Partition)", + values, + output = output) + + val projection = (1 to width) + .map(i => s"CAST(CAST(RAND(1) * 100 AS INTEGER) AS ${dataType.sql}) AS c$i") + .mkString(", ") + val columns = (1 to width).map(i => s"c$i").mkString(", ") + + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareTable(dir, spark.sql(s"SELECT $projection FROM $tbl")) + + benchmark.addCase("SQL Parquet - Spark") { _ => + spark + .sql(s"select $columns from parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + + benchmark.addCase("SQL Parquet - Comet (Spark Shuffle)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "false") { + spark + .sql(s"select $columns from parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + } + + benchmark.addCase("SQL Parquet - Comet (Comet Arrow Shuffle)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + spark + .sql(s"select $columns from parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + } + + benchmark.addCase("SQL Parquet - Comet (Comet Shuffle)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + spark + .sql(s"select $columns from parquetV1Table") + .repartition(partitionNum, Column("c1")) + .noop() + } + } + + benchmark.run() + } + } + } + + override def runCometBenchmark(mainArgs: Array[String]): Unit = { + runBenchmarkWithTable("Shuffle on array", 1024 * 1024 * 1) { v => + Seq( + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + StringType, + DecimalType(10, 0)).foreach { dataType => + Seq(5, 201).foreach { partitionNum => + shuffleArrayBenchmark(v, dataType, partitionNum) + } + } + } + + runBenchmarkWithTable("Shuffle on struct", 1024 * 1024 * 100) { v => + Seq( + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + StringType, + DecimalType(10, 0)).foreach { dataType => + Seq(5, 201).foreach { partitionNum => + shuffleStructBenchmark(v, dataType, partitionNum) + } + } + } + + runBenchmarkWithTable("Dictionary Shuffle", 1024 * 1024 * 100) { v => + Seq(BinaryType, StringType).foreach { dataType => + Seq(5, 201).foreach { partitionNum => + shuffleDictionaryBenchmark(v, dataType, partitionNum) + } + } + } + + runBenchmarkWithTable("Shuffle", 1024 * 1024 * 10) { v => + Seq( + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + StringType, + DecimalType(10, 0)) + .foreach { dataType => + shuffleBenchmark(v, dataType, false, 5) + } + } + + runBenchmarkWithTable("Shuffle", 1024 * 1024 * 10) { v => + Seq( + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + StringType, + DecimalType(10, 0)) + .foreach { dataType => + shuffleBenchmark(v, dataType, false, 201) + } + } + + runBenchmarkWithTable("Shuffle with random values", 1024 * 1024 * 10) { v => + Seq( + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + StringType, + DecimalType(10, 0)) + .foreach { dataType => + shuffleBenchmark(v, dataType, true, 5) + } + } + + runBenchmarkWithTable("Shuffle with random values", 1024 * 1024 * 10) { v => + Seq( + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + StringType, + DecimalType(10, 0)) + .foreach { dataType => + shuffleBenchmark(v, dataType, true, 201) + } + } + + runBenchmarkWithTable("Wide Shuffle (10 cols)", 1024 * 1024 * 10) { v => + Seq( + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + StringType, + DecimalType(10, 0)) + .foreach { dataType => + shuffleWideBenchmark(v, dataType, 10, 5) + } + } + + runBenchmarkWithTable("Wide Shuffle (20 cols)", 1024 * 1024 * 10) { v => + Seq( + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + StringType, + DecimalType(10, 0)) + .foreach { dataType => + shuffleWideBenchmark(v, dataType, 20, 5) + } + } + + runBenchmarkWithTable("Wide Shuffle (10 cols)", 1024 * 1024 * 10) { v => + Seq( + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + StringType, + DecimalType(10, 0)) + .foreach { dataType => + shuffleWideBenchmark(v, dataType, 10, 201) + } + } + + runBenchmarkWithTable("Wide Shuffle (20 cols)", 1024 * 1024 * 10) { v => + Seq( + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + StringType, + DecimalType(10, 0)) + .foreach { dataType => + shuffleWideBenchmark(v, dataType, 20, 201) + } + } + } +} diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala index 812d56738..c11387f6b 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala @@ -257,6 +257,7 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key -> "true", "spark.sql.readSideCharPadding" -> "false", @@ -278,11 +279,15 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa protected override def createSparkSession: TestSparkSession = { val conf = super.sparkConf conf.set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions") + conf.set( + "spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") conf.set(CometConf.COMET_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ENABLED.key, "true") conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "1g") conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true") + conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true") new TestSparkSession(new SparkContext("local[1]", this.getClass.getCanonicalName, conf)) }