diff --git a/common/src/main/java/org/apache/comet/parquet/ColumnReader.java b/common/src/main/java/org/apache/comet/parquet/ColumnReader.java
index d30befd74..8877ce8ca 100644
--- a/common/src/main/java/org/apache/comet/parquet/ColumnReader.java
+++ b/common/src/main/java/org/apache/comet/parquet/ColumnReader.java
@@ -25,7 +25,9 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.arrow.c.*;
+import org.apache.arrow.c.ArrowArray;
+import org.apache.arrow.c.ArrowSchema;
+import org.apache.arrow.c.CometSchemaImporter;
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.FieldVector;
diff --git a/common/src/main/java/org/apache/comet/parquet/MetadataColumnReader.java b/common/src/main/java/org/apache/comet/parquet/MetadataColumnReader.java
index 688970ae5..13b90e256 100644
--- a/common/src/main/java/org/apache/comet/parquet/MetadataColumnReader.java
+++ b/common/src/main/java/org/apache/comet/parquet/MetadataColumnReader.java
@@ -64,6 +64,7 @@ public void readBatch(int total) {
       FieldVector fieldVector = Data.importVector(allocator, array, schema, null);
       vector = new CometPlainVector(fieldVector, useDecimal128);
     }
+    vector.setNumValues(total);
   }
diff --git a/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs b/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
index ed64b80e7..e6528a563 100644
--- a/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
+++ b/native/core/src/execution/datafusion/expressions/bloom_filter_agg.rs
@@ -62,7 +62,17 @@ impl BloomFilterAgg {
         assert!(matches!(data_type, DataType::Binary));
         Self {
             name: name.into(),
-            signature: Signature::exact(vec![DataType::Int64], Volatility::Immutable),
+            signature: Signature::uniform(
+                1,
+                vec![
+                    DataType::Int8,
+                    DataType::Int16,
+                    DataType::Int32,
+                    DataType::Int64,
+                    DataType::Utf8,
+                ],
+                Volatility::Immutable,
+            ),
             expr,
             num_items: extract_i32_from_literal(num_items),
             num_bits: extract_i32_from_literal(num_bits),
@@ -112,10 +122,25 @@ impl Accumulator for SparkBloomFilter {
         (0..arr.len()).try_for_each(|index| {
             let v = ScalarValue::try_from_array(arr, index)?;
-            if let ScalarValue::Int64(Some(value)) = v {
-                self.put_long(value);
-            } else {
-                unreachable!()
+            match v {
+                ScalarValue::Int8(Some(value)) => {
+                    self.put_long(value as i64);
+                }
+                ScalarValue::Int16(Some(value)) => {
+                    self.put_long(value as i64);
+                }
+                ScalarValue::Int32(Some(value)) => {
+                    self.put_long(value as i64);
+                }
+                ScalarValue::Int64(Some(value)) => {
+                    self.put_long(value);
+                }
+                ScalarValue::Utf8(Some(value)) => {
+                    self.put_binary(value.as_bytes());
+                }
+                _ => {
+                    unreachable!()
+                }
             }
             Ok(())
         })
diff --git a/native/core/src/execution/datafusion/util/spark_bloom_filter.rs b/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
index 22a84d854..35fa23b46 100644
--- a/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
+++ b/native/core/src/execution/datafusion/util/spark_bloom_filter.rs
@@ -115,6 +115,23 @@ impl SparkBloomFilter {
         bit_changed
     }
 
+    pub fn put_binary(&mut self, item: &[u8]) -> bool {
+        // Here we first hash the input binary element into 2 int hash values, h1 and h2, then
+        // produce n hash values by `h1 + i * h2` with 1 <= i <= num_hash_functions.
+        let h1 = spark_compatible_murmur3_hash(item, 0);
+        let h2 = spark_compatible_murmur3_hash(item, h1);
+        let bit_size = self.bits.bit_size() as i32;
+        let mut bit_changed = false;
+        for i in 1..=self.num_hash_functions {
+            let mut combined_hash = (h1 as i32).add_wrapping((i as i32).mul_wrapping(h2 as i32));
+            if combined_hash < 0 {
+                combined_hash = !combined_hash;
+            }
+            bit_changed |= self.bits.set((combined_hash % bit_size) as usize)
+        }
+        bit_changed
+    }
+
     pub fn might_contain_long(&self, item: i64) -> bool {
         let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), 0);
         let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 3805d418b..abb138b0d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -769,11 +769,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         val numBitsExpr = exprToProto(numBits, inputs, binding)
         val dataType = serializeDataType(bloom_filter.dataType)
 
-        // TODO: Support more types
-        // https://github.com/apache/datafusion-comet/issues/1023
         if (childExpr.isDefined &&
-          child.dataType
-            .isInstanceOf[LongType] &&
+          (child.dataType
+            .isInstanceOf[ByteType] ||
+            child.dataType
+              .isInstanceOf[ShortType] ||
+            child.dataType
+              .isInstanceOf[IntegerType] ||
+            child.dataType
+              .isInstanceOf[LongType] ||
+            child.dataType
+              .isInstanceOf[StringType]) &&
           numItemsExpr.isDefined &&
           numBitsExpr.isDefined &&
           dataType.isDefined) {
diff --git a/spark/src/main/scala/org/apache/spark/Plugins.scala b/spark/src/main/scala/org/apache/spark/Plugins.scala
index ea5ef1663..b2cfef9aa 100644
--- a/spark/src/main/scala/org/apache/spark/Plugins.scala
+++ b/spark/src/main/scala/org/apache/spark/Plugins.scala
@@ -42,6 +42,8 @@ import org.apache.comet.{CometConf, CometSparkSessionExtensions}
  * To enable this plugin, set the config "spark.plugins" to `org.apache.spark.CometPlugin`.
  */
 class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPlugin {
+  private val EXECUTOR_MEMORY_DEFAULT = "1g"
+
   override def init(sc: SparkContext, pluginContext: PluginContext): ju.Map[String, String] = {
     logInfo("CometDriverPlugin init")
@@ -53,7 +55,7 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
       sc.getConf.getSizeAsMb(EXECUTOR_MEMORY_OVERHEAD.key)
     } else {
       // By default, executorMemory * spark.executor.memoryOverheadFactor, with minimum of 384MB
-      val executorMemory = sc.getConf.getSizeAsMb(EXECUTOR_MEMORY.key)
+      val executorMemory = sc.getConf.getSizeAsMb(EXECUTOR_MEMORY.key, EXECUTOR_MEMORY_DEFAULT)
       val memoryOverheadFactor = getMemoryOverheadFactor(sc.getConf)
       val memoryOverheadMinMib = getMemoryOverheadMinMib(sc.getConf)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index a720842ce..99007d0c9 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -946,8 +946,12 @@ class CometExecSuite extends CometTestBase {
         (0 until 100)
           .map(_ => (Random.nextInt(), Random.nextInt() % 5)),
         "tbl") {
-        val df = sql("SELECT bloom_filter_agg(cast(_2 as long)) FROM tbl")
-        checkSparkAnswerAndOperator(df)
+
+        (if (isSpark35Plus) Seq("tinyint", "short", "int", "long", "string") else Seq("long"))
+          .foreach { input_type =>
+            val df = sql(f"SELECT bloom_filter_agg(cast(_2 as $input_type)) FROM tbl")
+            checkSparkAnswerAndOperator(df)
+          }
       }
       spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)