diff --git a/core/src/execution/datafusion/shuffle_writer.rs b/core/src/execution/datafusion/shuffle_writer.rs index fc15facb4..f836e3a40 100644 --- a/core/src/execution/datafusion/shuffle_writer.rs +++ b/core/src/execution/datafusion/shuffle_writer.rs @@ -460,32 +460,32 @@ fn append_columns( }; } - macro_rules! append_string_dict { - ($kt:ident) => {{ + macro_rules! append_byte_dict { + ($kt:ident, $byte_type:ty, $array_type:ty) => {{ match $kt.as_ref() { DataType::Int8 => { - append_dict!(Int8Type, StringDictionaryBuilder, StringArray) + append_dict!(Int8Type, GenericByteDictionaryBuilder, $array_type) } DataType::Int16 => { - append_dict!(Int16Type, StringDictionaryBuilder, StringArray) + append_dict!(Int16Type, GenericByteDictionaryBuilder, $array_type) } DataType::Int32 => { - append_dict!(Int32Type, StringDictionaryBuilder, StringArray) + append_dict!(Int32Type, GenericByteDictionaryBuilder, $array_type) } DataType::Int64 => { - append_dict!(Int64Type, StringDictionaryBuilder, StringArray) + append_dict!(Int64Type, GenericByteDictionaryBuilder, $array_type) } DataType::UInt8 => { - append_dict!(UInt8Type, StringDictionaryBuilder, StringArray) + append_dict!(UInt8Type, GenericByteDictionaryBuilder, $array_type) } DataType::UInt16 => { - append_dict!(UInt16Type, StringDictionaryBuilder, StringArray) + append_dict!(UInt16Type, GenericByteDictionaryBuilder, $array_type) } DataType::UInt32 => { - append_dict!(UInt32Type, StringDictionaryBuilder, StringArray) + append_dict!(UInt32Type, GenericByteDictionaryBuilder, $array_type) } DataType::UInt64 => { - append_dict!(UInt64Type, StringDictionaryBuilder, StringArray) + append_dict!(UInt64Type, GenericByteDictionaryBuilder, $array_type) } _ => unreachable!("Unknown key type for dictionary"), } @@ -522,7 +522,22 @@ fn append_columns( DataType::Dictionary(key_type, value_type) if matches!(value_type.as_ref(), DataType::Utf8) => { - append_string_dict!(key_type) + append_byte_dict!(key_type, GenericStringType, StringArray) + } + DataType::Dictionary(key_type, value_type) + if matches!(value_type.as_ref(), DataType::LargeUtf8) => + { + append_byte_dict!(key_type, GenericStringType, LargeStringArray) + } + DataType::Dictionary(key_type, value_type) + if matches!(value_type.as_ref(), DataType::Binary) => + { + append_byte_dict!(key_type, GenericBinaryType, BinaryArray) + } + DataType::Dictionary(key_type, value_type) + if matches!(value_type.as_ref(), DataType::LargeBinary) => + { + append_byte_dict!(key_type, GenericBinaryType, LargeBinaryArray) } DataType::Binary => append!(Binary), DataType::LargeBinary => append!(LargeBinary), @@ -1028,7 +1043,7 @@ macro_rules! primitive_dict_builder_helper { }; } -macro_rules! string_dict_builder_inner_helper { +macro_rules! byte_dict_builder_inner_helper { ($kt:ty, $capacity:ident, $builder:ident) => { Box::new($builder::<$kt>::with_capacity( $capacity, @@ -1068,28 +1083,28 @@ fn make_dict_builder(datatype: &DataType, capacity: usize) -> Box { - string_dict_builder_inner_helper!(Int16Type, capacity, StringDictionaryBuilder) + byte_dict_builder_inner_helper!(Int8Type, capacity, StringDictionaryBuilder) } DataType::Int16 => { - string_dict_builder_inner_helper!(Int16Type, capacity, StringDictionaryBuilder) + byte_dict_builder_inner_helper!(Int16Type, capacity, StringDictionaryBuilder) } DataType::Int32 => { - string_dict_builder_inner_helper!(Int32Type, capacity, StringDictionaryBuilder) + byte_dict_builder_inner_helper!(Int32Type, capacity, StringDictionaryBuilder) } DataType::Int64 => { - string_dict_builder_inner_helper!(Int64Type, capacity, StringDictionaryBuilder) + byte_dict_builder_inner_helper!(Int64Type, capacity, StringDictionaryBuilder) } DataType::UInt8 => { - string_dict_builder_inner_helper!(UInt8Type, capacity, StringDictionaryBuilder) + byte_dict_builder_inner_helper!(UInt8Type, capacity, StringDictionaryBuilder) } DataType::UInt16 => { - string_dict_builder_inner_helper!(UInt16Type, capacity, StringDictionaryBuilder) + byte_dict_builder_inner_helper!(UInt16Type, capacity, StringDictionaryBuilder) } DataType::UInt32 => { - string_dict_builder_inner_helper!(UInt32Type, capacity, StringDictionaryBuilder) + byte_dict_builder_inner_helper!(UInt32Type, capacity, StringDictionaryBuilder) } DataType::UInt64 => { - string_dict_builder_inner_helper!(UInt64Type, capacity, StringDictionaryBuilder) + byte_dict_builder_inner_helper!(UInt64Type, capacity, StringDictionaryBuilder) } _ => unreachable!(""), } @@ -1098,47 +1113,47 @@ fn make_dict_builder(datatype: &DataType, capacity: usize) -> Box { match key_type.as_ref() { - DataType::Int8 => string_dict_builder_inner_helper!( - Int16Type, + DataType::Int8 => byte_dict_builder_inner_helper!( + Int8Type, capacity, LargeStringDictionaryBuilder ), - DataType::Int16 => string_dict_builder_inner_helper!( + DataType::Int16 => byte_dict_builder_inner_helper!( Int16Type, capacity, LargeStringDictionaryBuilder ), - DataType::Int32 => string_dict_builder_inner_helper!( + DataType::Int32 => byte_dict_builder_inner_helper!( Int32Type, capacity, LargeStringDictionaryBuilder ), - DataType::Int64 => string_dict_builder_inner_helper!( + DataType::Int64 => byte_dict_builder_inner_helper!( Int64Type, capacity, LargeStringDictionaryBuilder ), - DataType::UInt8 => string_dict_builder_inner_helper!( + DataType::UInt8 => byte_dict_builder_inner_helper!( UInt8Type, capacity, LargeStringDictionaryBuilder ), DataType::UInt16 => { - string_dict_builder_inner_helper!( + byte_dict_builder_inner_helper!( UInt16Type, capacity, LargeStringDictionaryBuilder ) } DataType::UInt32 => { - string_dict_builder_inner_helper!( + byte_dict_builder_inner_helper!( UInt32Type, capacity, LargeStringDictionaryBuilder ) } DataType::UInt64 => { - string_dict_builder_inner_helper!( + byte_dict_builder_inner_helper!( UInt64Type, capacity, LargeStringDictionaryBuilder @@ -1147,6 +1162,90 @@ fn make_dict_builder(datatype: &DataType, capacity: usize) -> Box unreachable!(""), } } + DataType::Dictionary(key_type, value_type) + if matches!(value_type.as_ref(), DataType::Binary) => + { + match key_type.as_ref() { + DataType::Int8 => { + byte_dict_builder_inner_helper!(Int8Type, capacity, BinaryDictionaryBuilder) + } + DataType::Int16 => { + byte_dict_builder_inner_helper!(Int16Type, capacity, BinaryDictionaryBuilder) + } + DataType::Int32 => { + byte_dict_builder_inner_helper!(Int32Type, capacity, BinaryDictionaryBuilder) + } + DataType::Int64 => { + byte_dict_builder_inner_helper!(Int64Type, capacity, BinaryDictionaryBuilder) + } + DataType::UInt8 => { + byte_dict_builder_inner_helper!(UInt8Type, capacity, BinaryDictionaryBuilder) + } + DataType::UInt16 => { + byte_dict_builder_inner_helper!(UInt16Type, capacity, BinaryDictionaryBuilder) + } + DataType::UInt32 => { + byte_dict_builder_inner_helper!(UInt32Type, capacity, BinaryDictionaryBuilder) + } + DataType::UInt64 => { + byte_dict_builder_inner_helper!(UInt64Type, capacity, BinaryDictionaryBuilder) + } + _ => unreachable!(""), + } + } + DataType::Dictionary(key_type, value_type) + if matches!(value_type.as_ref(), DataType::LargeBinary) => + { + match key_type.as_ref() { + DataType::Int8 => byte_dict_builder_inner_helper!( + Int8Type, + capacity, + LargeBinaryDictionaryBuilder + ), + DataType::Int16 => byte_dict_builder_inner_helper!( + Int16Type, + capacity, + LargeBinaryDictionaryBuilder + ), + DataType::Int32 => byte_dict_builder_inner_helper!( + Int32Type, + capacity, + LargeBinaryDictionaryBuilder + ), + DataType::Int64 => byte_dict_builder_inner_helper!( + Int64Type, + capacity, + LargeBinaryDictionaryBuilder + ), + DataType::UInt8 => byte_dict_builder_inner_helper!( + UInt8Type, + capacity, + LargeBinaryDictionaryBuilder + ), + DataType::UInt16 => { + byte_dict_builder_inner_helper!( + UInt16Type, + capacity, + LargeBinaryDictionaryBuilder + ) + } + DataType::UInt32 => { + byte_dict_builder_inner_helper!( + UInt32Type, + capacity, + LargeBinaryDictionaryBuilder + ) + } + DataType::UInt64 => { + byte_dict_builder_inner_helper!( + UInt64Type, + capacity, + LargeBinaryDictionaryBuilder + ) + } + _ => unreachable!(""), + } + } t => panic!("Data type {t:?} is not currently supported"), } } diff --git a/core/src/execution/datafusion/spark_hash.rs b/core/src/execution/datafusion/spark_hash.rs index 0413e4559..aeefccf5b 100644 --- a/core/src/execution/datafusion/spark_hash.rs +++ b/core/src/execution/datafusion/spark_hash.rs @@ -292,6 +292,12 @@ pub fn create_hashes<'a>( DataType::LargeUtf8 => { hash_array!(LargeStringArray, col, hashes_buffer); } + DataType::Binary => { + hash_array!(BinaryArray, col, hashes_buffer); + } + DataType::LargeBinary => { + hash_array!(LargeBinaryArray, col, hashes_buffer); + } DataType::FixedSizeBinary(_) => { hash_array!(FixedSizeBinaryArray, col, hashes_buffer); } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala index acd424acd..0d7c73d42 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala @@ -63,6 +63,25 @@ abstract class CometShuffleSuiteBase extends CometTestBase with AdaptiveSparkPla import testImplicits._ + test("Native shuffle with dictionary of binary") { + Seq("true", "false").foreach { dictionaryEnabled => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") { + withParquetTable( + (0 until 1000).map(i => (i % 5, (i % 5).toString.getBytes())), + "tbl", + dictionaryEnabled.toBoolean) { + val shuffled = sql("SELECT * FROM tbl").repartition(2, $"_2") + + checkCometExchange(shuffled, 1, true) + checkSparkAnswer(shuffled) + } + } + } + } + test("columnar shuffle on nested struct including nulls") { Seq(10, 201).foreach { numPartitions => Seq("1.0", "10.0").foreach { ratio =>