diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 8815ac4eb..3bd1fc688 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -278,7 +278,7 @@ object CometConf extends ShimCometConf {
-      "The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported. " +
+      "The codec of Comet native shuffle used to compress shuffle data. Only zstd and lz4 are supported. " +
        "Compression can be disabled by setting spark.shuffle.compress=false.")
     .stringConf
-    .checkValues(Set("zstd"))
+    .checkValues(Set("zstd", "lz4"))
     .createWithDefault("zstd")
 
   val COMET_EXEC_SHUFFLE_COMPRESSION_LEVEL: ConfigEntry[Int] =
diff --git a/docs/source/user-guide/tuning.md b/docs/source/user-guide/tuning.md
index e04e750b4..ab0a44a27 100644
--- a/docs/source/user-guide/tuning.md
+++ b/docs/source/user-guide/tuning.md
@@ -105,6 +105,10 @@ then any shuffle operations that cannot be supported in this mode will fall back
 
 ### Shuffle Compression
 
+Comet supports LZ4 and ZSTD compression for shuffle data. LZ4 is typically faster, while ZSTD typically
+achieves a better compression ratio. The ZSTD compression level can be set with
+`spark.comet.exec.shuffle.compression.level`.
+
 By default, Spark compresses shuffle files using LZ4 compression. Comet overrides this behavior with ZSTD
 compression. Compression can be disabled by setting `spark.shuffle.compress=false`, which may result in
 faster shuffle times in certain environments, such as single-node setups with fast NVMe drives, at the
 expense of increased disk space usage.
diff --git a/native/Cargo.lock b/native/Cargo.lock
index 538c40ee2..dfed0cc70 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -920,7 +920,7 @@ dependencies = [
  "lazy_static",
  "log",
  "log4rs",
- "lz4",
+ "lz4_flex",
  "mimalloc",
  "num",
  "once_cell",
@@ -2111,25 +2111,6 @@ dependencies = [
  "winapi",
 ]
 
-[[package]]
-name = "lz4"
-version = "1.28.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4d1febb2b4a79ddd1980eede06a8f7902197960aa0383ffcfdd62fe723036725"
-dependencies = [
- "lz4-sys",
-]
-
-[[package]]
-name = "lz4-sys"
-version = "1.11.1+lz4-1.10.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6"
-dependencies = [
- "cc",
- "libc",
-]
-
 [[package]]
 name = "lz4_flex"
 version = "0.11.3"
diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml
index 489da46d4..197bf4318 100644
--- a/native/core/Cargo.toml
+++ b/native/core/Cargo.toml
@@ -55,7 +55,7 @@ jni = "0.21"
 snap = "1.1"
 brotli = "3.3"
 flate2 = "1.0"
-lz4 = "1.24"
+lz4_flex = "0.11.3"
 zstd = "0.11"
 rand = { workspace = true}
 num = { workspace = true }
diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index eb73675b5..dd5b835ad 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -60,6 +60,7 @@ use jni::{
 use tokio::runtime::Runtime;
 
 use crate::execution::operators::ScanExec;
+use crate::execution::shuffle::{read_ipc_compressed_lz4, read_ipc_compressed_zstd};
 use crate::execution::spark_plan::SparkPlan;
 
 use log::info;
@@ -544,3 +545,45 @@ pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative(
         Ok(())
     })
 }
+
+#[no_mangle]
+/// Used by Comet native shuffle reader
+pub extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlockLz4(
+    e: JNIEnv,
+    _class: JClass,
+    address: jlong,
+    size: jlong,
+) {
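+    // Safety/usage note: `address` and `size` must describe a valid buffer of compressed,
+    // IPC-encoded bytes that the JVM caller keeps alive for the duration of this call. The
+    // decoded batch is currently dropped (see the TODO below) until it can be returned over FFI.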
+    try_unwrap_or_throw(&e, |_| {
+        let ipc_encoded_bytes =
+            unsafe { std::slice::from_raw_parts(address as *const u8, size as usize) };
+        let start = Instant::now();
+        let _batch = read_ipc_compressed_lz4(ipc_encoded_bytes)?;
+        println!("native decode batch in {:?}", start.elapsed());
+
+        // TODO return batch via FFI
+
+        Ok(())
+    })
+}
+
+#[no_mangle]
+/// Used by Comet native shuffle reader
+pub extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlockZstd(
+    e: JNIEnv,
+    _class: JClass,
+    address: jlong,
+    size: jlong,
+) {
+    try_unwrap_or_throw(&e, |_| {
+        let ipc_encoded_bytes =
+            unsafe { std::slice::from_raw_parts(address as *const u8, size as usize) };
+        let start = Instant::now();
+        let _batch = read_ipc_compressed_zstd(ipc_encoded_bytes)?;
+        println!("native decode batch in {:?}", start.elapsed());
+
+        // TODO return batch via FFI
+
+        Ok(())
+    })
+}
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 0a7493354..199854b9c 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -1055,6 +1055,7 @@ impl PhysicalPlanner {
                 Ok(SparkCompressionCodec::Zstd) => {
                     Ok(CompressionCodec::Zstd(writer.compression_level))
                 }
+                Ok(SparkCompressionCodec::Lz4) => Ok(CompressionCodec::Lz4Frame),
                 _ => Err(ExecutionError::GeneralError(format!(
                     "Unsupported shuffle compression codec: {:?}",
                     writer.codec
diff --git a/native/core/src/execution/shuffle/mod.rs b/native/core/src/execution/shuffle/mod.rs
index 8111f5eed..8c714ff68 100644
--- a/native/core/src/execution/shuffle/mod.rs
+++ b/native/core/src/execution/shuffle/mod.rs
@@ -19,4 +19,7 @@ mod list;
 mod map;
 pub mod row;
 mod shuffle_writer;
-pub use shuffle_writer::{write_ipc_compressed, CompressionCodec, ShuffleWriterExec};
+pub use shuffle_writer::{
+    read_ipc_compressed_lz4, read_ipc_compressed_zstd, write_ipc_compressed, CompressionCodec,
+    ShuffleWriterExec,
+};
diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs b/native/core/src/execution/shuffle/shuffle_writer.rs
index 01117199e..2b38c9671 100644
--- a/native/core/src/execution/shuffle/shuffle_writer.rs
+++ b/native/core/src/execution/shuffle/shuffle_writer.rs
@@ -21,6 +21,7 @@ use crate::{
     common::bit::ceil,
     errors::{CometError, CometResult},
 };
+use arrow::ipc::reader::StreamReader;
 use arrow::{datatypes::*, ipc::writer::StreamWriter};
 use async_trait::async_trait;
 use bytes::Buf;
@@ -788,6 +789,8 @@ impl ShuffleRepartitioner {
             Partitioning::Hash(exprs, _) => {
                 let (partition_starts, shuffled_partition_ids): (Vec<u32>, Vec<u32>) = {
                     let mut timer = self.metrics.repart_time.timer();
+
+                    // evaluate partition expressions
                     let arrays = exprs
                         .iter()
                         .map(|expr| expr.evaluate(&input)?.into_array(input.num_rows()))
@@ -1547,6 +1550,7 @@ impl Checksum {
 #[derive(Debug, Clone)]
 pub enum CompressionCodec {
     None,
+    Lz4Frame,
     Zstd(i32),
 }
 
@@ -1575,6 +1579,23 @@ pub fn write_ipc_compressed(
             arrow_writer.finish()?;
             arrow_writer.into_inner()?
        }
+        CompressionCodec::Lz4Frame => {
+            // write IPC first without compression
+            let mut buffer = vec![];
+            let mut arrow_writer = StreamWriter::try_new(&mut buffer, &batch.schema())?;
+            arrow_writer.write(batch)?;
+            arrow_writer.finish()?;
+            let ipc_encoded = arrow_writer.into_inner()?;
+
+            // compress
+            let mut reader = Cursor::new(ipc_encoded);
+            let mut wtr = lz4_flex::frame::FrameEncoder::new(output);
+            std::io::copy(&mut reader, &mut wtr)?;
+            let output = wtr
+                .finish()
+                .map_err(|e| DataFusionError::Execution(format!("lz4 compression error: {}", e)))?;
+            output
+        }
         CompressionCodec::Zstd(level) => {
             let encoder = zstd::Encoder::new(output, *level)?;
             let mut arrow_writer = StreamWriter::try_new(encoder, &batch.schema())?;
@@ -1599,6 +1620,20 @@ pub fn write_ipc_compressed(
     Ok((end_pos - start_pos) as usize)
 }
 
+pub fn read_ipc_compressed_lz4(bytes: &[u8]) -> Result<RecordBatch> {
+    let decoder = lz4_flex::frame::FrameDecoder::new(bytes);
+    let mut reader = StreamReader::try_new(decoder, None)?;
+    // TODO check for None
+    reader.next().unwrap().map_err(|e| e.into())
+}
+
+pub fn read_ipc_compressed_zstd(bytes: &[u8]) -> Result<RecordBatch> {
+    let decoder = zstd::Decoder::new(bytes)?;
+    let mut reader = StreamReader::try_new(decoder, None)?;
+    // TODO check for None
+    reader.next().unwrap().map_err(|e| e.into())
+}
+
 /// A stream that yields no record batches which represent end of output.
 pub struct EmptyStream {
     /// Schema representing the data
@@ -1648,11 +1683,11 @@ mod test {
 
     #[test]
     #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx`
-    fn write_ipc_zstd() {
+    fn roundtrip_ipc_zstd() {
         let batch = create_batch(8192);
         let mut output = vec![];
         let mut cursor = Cursor::new(&mut output);
-        write_ipc_compressed(
+        let length = write_ipc_compressed(
             &batch,
             &mut cursor,
             &CompressionCodec::Zstd(1),
@@ -1660,6 +1695,37 @@ mod test {
         )
         .unwrap();
         assert_eq!(40218, output.len());
+
+        // generate file that can be tested on JVM side in org.apache.spark.CometShuffleCodecSuite
+        write_ipc_file("/tmp/shuffle.zstd", &output);
+
+        let ipc_without_length_prefix = &output[8..length];
+        let batch2 = read_ipc_compressed_zstd(ipc_without_length_prefix).unwrap();
+        assert_eq!(batch, batch2);
+    }
+
+    #[test]
+    #[cfg_attr(miri, ignore)] // miri cannot write files to the local filesystem
+    fn write_ipc_lz4_frame() {
+        let batch = create_batch(8192);
+        let mut output = vec![];
+        let mut cursor = Cursor::new(&mut output);
+        write_ipc_compressed(
+            &batch,
+            &mut cursor,
+            &CompressionCodec::Lz4Frame,
+            &Time::default(),
+        )
+        .unwrap();
+        assert_eq!(61744, output.len());
+
+        // generate file that can be tested on JVM side in org.apache.spark.CometShuffleCodecSuite
+        write_ipc_file("/tmp/shuffle.lz4", &output);
+    }
+
+    fn write_ipc_file(filename: &str, output: &[u8]) {
+        let mut file = File::create(filename).unwrap();
+        file.write_all(&output).unwrap()
     }
 
     #[test]
diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto
index 5cb2802da..08c16a802 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -85,6 +85,7 @@ message Limit {
 enum CompressionCodec {
   None = 0;
   Zstd = 1;
+  Lz4 = 2;
 }
 
 message ShuffleWriter {
diff --git a/pom.xml b/pom.xml
index cdc44a5ca..b49e7bdea 100644
--- a/pom.xml
+++ b/pom.xml
@@ -158,6 +158,12 @@ under the License.
+      <dependency>
+        <groupId>org.lz4</groupId>
+        <artifactId>lz4-java</artifactId>
+        <version>1.8.0</version>
+      </dependency>
+
       <dependency>
         <groupId>org.apache.arrow</groupId>
diff --git a/spark/pom.xml b/spark/pom.xml
index ad7590dbc..111e73baf 100644
--- a/spark/pom.xml
+++ b/spark/pom.xml
@@ -135,6 +135,10 @@ under the License.
       <artifactId>arrow-c-data</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.lz4</groupId>
+      <artifactId>lz4-java</artifactId>
+    </dependency>
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala
index 083c0f2b5..414e2ef8c 100644
--- a/spark/src/main/scala/org/apache/comet/Native.scala
+++ b/spark/src/main/scala/org/apache/comet/Native.scala
@@ -139,4 +139,22 @@ class Native extends NativeBase {
    *   the size of the array.
    */
   @native def sortRowPartitionsNative(addr: Long, size: Long): Unit
+
+  /**
+   * Decompress and decode a native shuffle block that was compressed with LZ4.
+   * @param addr
+   *   the address of the array of compressed and encoded bytes
+   * @param size
+   *   the size of the array.
+   */
+  @native def decodeShuffleBlockLz4(addr: Long, size: Long): Unit
+
+  /**
+   * Decompress and decode a native shuffle block that was compressed with ZSTD.
+   * @param addr
+   *   the address of the array of compressed and encoded bytes
+   * @param size
+   *   the size of the array.
+   */
+  @native def decodeShuffleBlockZstd(addr: Long, size: Long): Unit
 }
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
similarity index 100%
rename from common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
rename to spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
similarity index 92%
rename from common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
rename to spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
index e026cbeb1..74c655950 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
@@ -21,21 +21,12 @@ package org.apache.spark.sql.comet.execution.shuffle
 
 import java.io.InputStream
 
-import org.apache.spark.InterruptibleIterator
-import org.apache.spark.MapOutputTracker
-import org.apache.spark.SparkEnv
-import org.apache.spark.TaskContext
-import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config
+import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.serializer.SerializerManager
-import org.apache.spark.shuffle.BaseShuffleHandle
-import org.apache.spark.shuffle.ShuffleReader
-import org.apache.spark.shuffle.ShuffleReadMetricsReporter
-import org.apache.spark.storage.BlockId
-import org.apache.spark.storage.BlockManager
-import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.storage.ShuffleBlockFetcherIterator
+import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader, ShuffleReadMetricsReporter}
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator}
 import org.apache.spark.util.CompletionIterator
 
 /**
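Aside: a minimal, hypothetical sketch of how the two new `Native.decodeShuffleBlock*` entry points declared above might eventually be driven from the JVM side once the decoded batch can be returned over FFI. Only the `decodeShuffleBlockLz4`/`decodeShuffleBlockZstd` signatures and the `Native` class come from this patch; the Arrow off-heap allocation and every other name here are illustrative assumptions, not code from this change.

```scala
import org.apache.arrow.memory.RootAllocator

import org.apache.comet.Native

object DecodeShuffleBlockExample {

  /** Copies a compressed IPC block off-heap and hands its address to the native decoder. */
  def decode(blockBytes: Array[Byte], codec: String): Unit = {
    val allocator = new RootAllocator(Long.MaxValue)
    val buf = allocator.buffer(blockBytes.length.toLong)
    try {
      // the native side reads `size` bytes starting at `address`
      buf.writeBytes(blockBytes)
      val native = new Native()
      codec match {
        case "lz4" => native.decodeShuffleBlockLz4(buf.memoryAddress(), blockBytes.length.toLong)
        case "zstd" => native.decodeShuffleBlockZstd(buf.memoryAddress(), blockBytes.length.toLong)
        case other => throw new UnsupportedOperationException(s"invalid codec: $other")
      }
    } finally {
      buf.close()
      allocator.close()
    }
  }
}
```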
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
similarity index 100%
rename from common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
rename to spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 3a11b8b28..8485c0e12 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -554,16 +554,15 @@ class CometShuffleWriteProcessor(
     shuffleWriterBuilder.setOutputDataFile(dataFile)
     shuffleWriterBuilder.setOutputIndexFile(indexFile)
 
-    if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) {
-      val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match {
-        case "zstd" => CompressionCodec.Zstd
-        case other => throw new UnsupportedOperationException(s"invalid codec: $other")
-      }
-      shuffleWriterBuilder.setCodec(codec)
-    } else {
-      shuffleWriterBuilder.setCodec(CompressionCodec.None)
+    // TODO remove hard-coded compression levels
+    val (codec, level) = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match {
+      case "none" => (CompressionCodec.None, 0)
+      case "lz4" => (CompressionCodec.Lz4, 0)
+      case "zstd" => (CompressionCodec.Zstd, 1)
+      case other => throw new UnsupportedOperationException(s"invalid codec: $other")
     }
-    shuffleWriterBuilder.setCompressionLevel(CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_LEVEL.get)
+    shuffleWriterBuilder.setCodec(codec)
+    shuffleWriterBuilder.setCompressionLevel(level)
 
     outputPartitioning match {
       case _: HashPartitioning =>
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
index b2cc2c2ba..1792311e6 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
@@ -29,8 +29,6 @@ import org.apache.spark.SparkConf
 import org.apache.spark.SparkEnv
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.{config, Logging}
-import org.apache.spark.internal.config.IO_COMPRESSION_CODEC
-import org.apache.spark.io.CompressionCodec
 import org.apache.spark.shuffle._
 import org.apache.spark.shuffle.api.ShuffleExecutorComponents
 import org.apache.spark.shuffle.sort.{BypassMergeSortShuffleHandle, SerializedShuffleHandle, SortShuffleManager, SortShuffleWriter}
@@ -241,18 +239,6 @@ object CometShuffleManager extends Logging {
     executorComponents
   }
 
-  lazy val compressionCodecForShuffling: CompressionCodec = {
-    val sparkConf = SparkEnv.get.conf
-    val codecName = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get(SQLConf.get)
-
-    // only zstd compression is supported at the moment
-    if (codecName != "zstd") {
-      logWarning(
-        s"Overriding config ${IO_COMPRESSION_CODEC}=${codecName} in shuffling, force using zstd")
-    }
-    CompressionCodec.createCodec(sparkConf, "zstd")
-  }
-
   def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
     // We cannot bypass sorting if we need to do map-side aggregation.
     if (dep.mapSideCombine) {
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala
similarity index 95%
rename from common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala
rename to spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala
index af78ed290..ba6fc588e 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql.comet.execution.shuffle
 
-import org.apache.spark.{Dependency, MapOutputTrackerMaster, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext}
+import org.apache.spark._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.execution.{CoalescedMapperPartitionSpec, CoalescedPartitioner, CoalescedPartitionSpec, PartialMapperPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.vectorized.ColumnarBatch
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala
similarity index 93%
rename from common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala
rename to spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala
index d1d5af350..ee2ff0cc6 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala
@@ -19,12 +19,9 @@ package org.apache.spark.sql.comet.execution.shuffle
 
-import java.io.EOFException
-import java.io.InputStream
-import java.nio.ByteBuffer
-import java.nio.ByteOrder
-import java.nio.channels.Channels
-import java.nio.channels.ReadableByteChannel
+import java.io.{EOFException, InputStream}
+import java.nio.{ByteBuffer, ByteOrder}
+import java.nio.channels.{Channels, ReadableByteChannel}
 
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
@@ -109,6 +106,9 @@ case class IpcInputStreamIterator(
       val is = new LimitedInputStream(Channels.newInputStream(channel), currentIpcLength, false)
       currentLimitedInputStream = is
 
+      val ipcBytes = is.readAllBytes()
+      // TODO call Native.decodeShuffleBlockLz4 but it is in the wrong module
+
       if (decompressingNeeded) {
         ShuffleUtils.compressionCodecForShuffling match {
           case Some(codec) => Channels.newChannel(codec.compressedInputStream(is))
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ShuffleUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ShuffleUtils.scala
similarity index 100%
rename from common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ShuffleUtils.scala
rename to spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ShuffleUtils.scala
diff --git a/spark/src/test/scala/org/apache/spark/CometShuffleCodecSuite.scala b/spark/src/test/scala/org/apache/spark/CometShuffleCodecSuite.scala
new file mode 100644
index 000000000..007389099
--- /dev/null
+++ b/spark/src/test/scala/org/apache/spark/CometShuffleCodecSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark
+
+import java.io.FileInputStream
+
+import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, IpcInputStreamIterator}
+
+import org.apache.comet.CometConf
+
+/**
+ * Manual tests for verifying compatibility of shuffle files generated by the native tests.
+ */
+class CometShuffleCodecSuite extends CometTestBase {
+
+  ignore("decode shuffle batch with zstd compression") {
+    withSQLConf(CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.key -> "zstd") {
+      // test file is created by shuffle_writer.rs test `roundtrip_ipc_zstd`
+      val is = new FileInputStream("/tmp/shuffle.zstd")
+      val ins = IpcInputStreamIterator(is, decompressingNeeded = true, TaskContext.get())
+      assert(ins.hasNext)
+      val channel = ins.next()
+      val it = new ArrowReaderIterator(channel, "test")
+      assert(it.hasNext)
+      val batch = it.next()
+      assert(8192 == batch.numRows())
+    }
+  }
+
+  test("decode shuffle batch with lz4 compression") {
+    withSQLConf(CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.key -> "lz4") {
+      // test file is created by shuffle_writer.rs test `write_ipc_lz4_frame`
+      val is = new FileInputStream("/tmp/shuffle.lz4")
+      val ins = IpcInputStreamIterator(is, decompressingNeeded = true, TaskContext.get())
+      assert(ins.hasNext)
+      val channel = ins.next()
+      val it = new ArrowReaderIterator(channel, "test")
+      assert(it.hasNext)
+      val batch = it.next()
+      assert(8192 == batch.numRows())
+    }
+  }
+
+}
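For context, a minimal sketch of enabling the new codec from a Spark application, assuming Comet's usual plugin and shuffle-manager settings. Only the two shuffle compression settings relate to this change; the exact `spark.comet.exec.shuffle.compression.codec` key name and the other keys are not shown in this diff and are assumptions here.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative configuration only; values other than the compression settings are the usual
// Comet options and are not introduced by this change.
val spark = SparkSession
  .builder()
  .appName("comet-lz4-shuffle")
  .config("spark.plugins", "org.apache.spark.CometPlugin")
  .config(
    "spark.shuffle.manager",
    "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
  .config("spark.comet.exec.shuffle.enabled", "true")
  .config("spark.comet.exec.shuffle.compression.codec", "lz4") // new value accepted by this change
  .config("spark.comet.exec.shuffle.compression.level", "1") // only applies to zstd
  .getOrCreate()
```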