diff --git a/common/src/main/java/org/apache/comet/CometOutOfMemoryError.java b/common/src/main/java/org/apache/comet/CometOutOfMemoryError.java new file mode 100644 index 000000000..8a9e8d1db --- /dev/null +++ b/common/src/main/java/org/apache/comet/CometOutOfMemoryError.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet; + +/** OOM error specific for Comet memory management */ +public class CometOutOfMemoryError extends OutOfMemoryError { + public CometOutOfMemoryError(String msg) { + super(msg); + } +} diff --git a/core/src/errors.rs b/core/src/errors.rs index e99af7aa6..1d5766cb9 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -101,7 +101,11 @@ pub enum CometError { #[from] source: std::num::ParseFloatError, }, - + #[error(transparent)] + BoolFormat { + #[from] + source: std::str::ParseBoolError, + }, #[error(transparent)] Format { #[from] diff --git a/core/src/execution/jni_api.rs b/core/src/execution/jni_api.rs index 1d55d3f92..20f98a3a4 100644 --- a/core/src/execution/jni_api.rs +++ b/core/src/execution/jni_api.rs @@ -42,7 +42,7 @@ use jni::{ }; use std::{collections::HashMap, sync::Arc, task::Poll}; -use super::{serde, utils::SparkArrowConvert}; +use super::{serde, utils::SparkArrowConvert, CometMemoryPool}; use crate::{ errors::{try_unwrap_or_throw, CometError, CometResult}, @@ -103,6 +103,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( iterators: jobjectArray, serialized_query: jbyteArray, metrics_node: JObject, + comet_task_memory_manager_obj: JObject, ) -> jlong { try_unwrap_or_throw(&e, |mut env| { // Init JVM classes @@ -147,11 +148,13 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( let input_source = Arc::new(jni_new_global_ref!(env, input_source)?); input_sources.push(input_source); } + let task_memory_manager = + Arc::new(jni_new_global_ref!(env, comet_task_memory_manager_obj)?); // We need to keep the session context alive. Some session state like temporary // dictionaries are stored in session context. If it is dropped, the temporary // dictionaries will be dropped as well. - let session = prepare_datafusion_session_context(&configs)?; + let session = prepare_datafusion_session_context(&configs, task_memory_manager)?; let exec_context = Box::new(ExecutionContext { id, @@ -175,6 +178,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( /// Parse Comet configs and configure DataFusion session context. fn prepare_datafusion_session_context( conf: &HashMap, + comet_task_memory_manager: Arc, ) -> CometResult { // Get the batch size from Comet JVM side let batch_size = conf @@ -186,18 +190,30 @@ fn prepare_datafusion_session_context( let mut rt_config = RuntimeConfig::new().with_disk_manager(DiskManagerConfig::NewOs); - // Set up memory limit if specified - if conf.contains_key("memory_limit") { - let memory_limit = conf.get("memory_limit").unwrap().parse::()?; - - let memory_fraction = conf - .get("memory_fraction") - .ok_or(CometError::Internal( - "Config 'memory_fraction' is not specified from Comet JVM side".to_string(), - ))? - .parse::()?; - - rt_config = rt_config.with_memory_limit(memory_limit, memory_fraction); + // Check if we are using unified memory manager integrated with Spark. Default to false if not + // set. + let use_unified_memory_manager = conf + .get("use_unified_memory_manager") + .map(String::as_str) + .unwrap_or("false") + .parse::()?; + + if use_unified_memory_manager { + // Set Comet memory pool for native + let memory_pool = CometMemoryPool::new(comet_task_memory_manager); + rt_config = rt_config.with_memory_pool(Arc::new(memory_pool)); + } else { + // Use the memory pool from DF + if conf.contains_key("memory_limit") { + let memory_limit = conf.get("memory_limit").unwrap().parse::()?; + let memory_fraction = conf + .get("memory_fraction") + .ok_or(CometError::Internal( + "Config 'memory_fraction' is not specified from Comet JVM side".to_string(), + ))? + .parse::()?; + rt_config = rt_config.with_memory_limit(memory_limit, memory_fraction) + } } // Get Datafusion configuration from Spark Execution context diff --git a/core/src/execution/memory_pool.rs b/core/src/execution/memory_pool.rs new file mode 100644 index 000000000..ff2369095 --- /dev/null +++ b/core/src/execution/memory_pool.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + fmt::{Debug, Formatter, Result as FmtResult}, + sync::{ + atomic::{AtomicUsize, Ordering::Relaxed}, + Arc, + }, +}; + +use jni::objects::GlobalRef; + +use datafusion::{ + common::DataFusionError, + execution::memory_pool::{MemoryPool, MemoryReservation}, +}; + +use crate::{ + errors::CometResult, + jvm_bridge::{jni_call, JVMClasses}, +}; + +/// A DataFusion `MemoryPool` implementation for Comet. Internally this is +/// implemented via delegating calls to [`crate::jvm_bridge::CometTaskMemoryManager`]. +pub struct CometMemoryPool { + task_memory_manager_handle: Arc, + used: AtomicUsize, +} + +impl Debug for CometMemoryPool { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + f.debug_struct("CometMemoryPool") + .field("used", &self.used.load(Relaxed)) + .finish() + } +} + +impl CometMemoryPool { + pub fn new(task_memory_manager_handle: Arc) -> CometMemoryPool { + Self { + task_memory_manager_handle, + used: AtomicUsize::new(0), + } + } + + fn acquire(&self, additional: usize) -> CometResult { + let mut env = JVMClasses::get_env(); + let handle = self.task_memory_manager_handle.as_obj(); + unsafe { + jni_call!(&mut env, + comet_task_memory_manager(handle).acquire_memory(additional as i64) -> i64) + } + } + + fn release(&self, size: usize) -> CometResult<()> { + let mut env = JVMClasses::get_env(); + let handle = self.task_memory_manager_handle.as_obj(); + unsafe { + jni_call!(&mut env, comet_task_memory_manager(handle).release_memory(size as i64) -> ()) + } + } +} + +unsafe impl Send for CometMemoryPool {} +unsafe impl Sync for CometMemoryPool {} + +impl MemoryPool for CometMemoryPool { + fn grow(&self, _: &MemoryReservation, additional: usize) { + self.acquire(additional) + .unwrap_or_else(|_| panic!("Failed to acquire {} bytes", additional)); + self.used.fetch_add(additional, Relaxed); + } + + fn shrink(&self, _: &MemoryReservation, size: usize) { + self.release(size) + .unwrap_or_else(|_| panic!("Failed to release {} bytes", size)); + self.used.fetch_sub(size, Relaxed); + } + + fn try_grow(&self, _: &MemoryReservation, additional: usize) -> Result<(), DataFusionError> { + if additional > 0 { + let acquired = self.acquire(additional)?; + // If the number of bytes we acquired is less than the requested, return an error, + // and hopefully will trigger spilling from the caller side. + if acquired < additional as i64 { + // Release the acquired bytes before throwing error + self.release(acquired as usize)?; + + return Err(DataFusionError::Execution(format!( + "Failed to acquire {} bytes, only got {}. Reserved: {}", + additional, + acquired, + self.reserved(), + ))); + } + self.used.fetch_add(additional, Relaxed); + } + Ok(()) + } + + fn reserved(&self) -> usize { + self.used.load(Relaxed) + } +} diff --git a/core/src/execution/mod.rs b/core/src/execution/mod.rs index 4c57ad8eb..b3be83b5f 100644 --- a/core/src/execution/mod.rs +++ b/core/src/execution/mod.rs @@ -29,6 +29,9 @@ pub(crate) mod sort; mod timezone; pub(crate) mod utils; +mod memory_pool; +pub use memory_pool::*; + // Include generated modules from .proto files. #[allow(missing_docs)] pub mod spark_expression { diff --git a/core/src/jvm_bridge/comet_task_memory_manager.rs b/core/src/jvm_bridge/comet_task_memory_manager.rs new file mode 100644 index 000000000..a79a5b67d --- /dev/null +++ b/core/src/jvm_bridge/comet_task_memory_manager.rs @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use jni::{ + errors::Result as JniResult, + objects::{JClass, JMethodID}, + signature::{Primitive, ReturnType}, + JNIEnv, +}; + +use crate::jvm_bridge::get_global_jclass; + +/// A wrapper which delegate acquire/release memory calls to the +/// JVM side `CometTaskMemoryManager`. +#[derive(Debug)] +pub struct CometTaskMemoryManager<'a> { + pub class: JClass<'a>, + pub method_acquire_memory: JMethodID, + pub method_release_memory: JMethodID, + + pub method_acquire_memory_ret: ReturnType, + pub method_release_memory_ret: ReturnType, +} + +impl<'a> CometTaskMemoryManager<'a> { + pub const JVM_CLASS: &'static str = "org/apache/spark/CometTaskMemoryManager"; + + pub fn new(env: &mut JNIEnv<'a>) -> JniResult> { + let class = get_global_jclass(env, Self::JVM_CLASS)?; + + let result = CometTaskMemoryManager { + class, + method_acquire_memory: env.get_method_id( + Self::JVM_CLASS, + "acquireMemory", + "(J)J".to_string(), + )?, + method_release_memory: env.get_method_id( + Self::JVM_CLASS, + "releaseMemory", + "(J)V".to_string(), + )?, + method_acquire_memory_ret: ReturnType::Primitive(Primitive::Long), + method_release_memory_ret: ReturnType::Primitive(Primitive::Void), + }; + Ok(result) + } +} diff --git a/core/src/jvm_bridge/mod.rs b/core/src/jvm_bridge/mod.rs index 7a2882e30..41376f03b 100644 --- a/core/src/jvm_bridge/mod.rs +++ b/core/src/jvm_bridge/mod.rs @@ -73,7 +73,7 @@ macro_rules! jni_call { let ret = $env.call_method_unchecked($obj, method_id, ret_type, args); // Check if JVM has thrown any exception, and handle it if so. - let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env)? { + let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env).unwrap() { Err(exception.into()) } else { $crate::jvm_bridge::jni_map_error!($env, ret) @@ -194,10 +194,12 @@ mod comet_exec; pub use comet_exec::*; mod batch_iterator; mod comet_metric_node; +mod comet_task_memory_manager; use crate::{errors::CometError, JAVA_VM}; use batch_iterator::CometBatchIterator; pub use comet_metric_node::*; +pub use comet_task_memory_manager::*; /// The JVM classes that are used in the JNI calls. pub struct JVMClasses<'a> { @@ -216,6 +218,9 @@ pub struct JVMClasses<'a> { pub comet_exec: CometExec<'a>, /// The CometBatchIterator class. Used for iterating over the batches. pub comet_batch_iterator: CometBatchIterator<'a>, + /// The CometTaskMemoryManager used for interacting with JVM side to + /// acquire & release native memory. + pub comet_task_memory_manager: CometTaskMemoryManager<'a>, } unsafe impl<'a> Send for JVMClasses<'a> {} @@ -261,6 +266,7 @@ impl JVMClasses<'_> { comet_metric_node: CometMetricNode::new(env).unwrap(), comet_exec: CometExec::new(env).unwrap(), comet_batch_iterator: CometBatchIterator::new(env).unwrap(), + comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), } }); } diff --git a/dev/ensure-jars-have-correct-contents.sh b/dev/ensure-jars-have-correct-contents.sh index 1ab09a5f8..5543093ff 100755 --- a/dev/ensure-jars-have-correct-contents.sh +++ b/dev/ensure-jars-have-correct-contents.sh @@ -80,6 +80,8 @@ allowed_expr+="|^org/apache/spark/shuffle/comet/.*$" allowed_expr+="|^org/apache/spark/sql/$" allowed_expr+="|^org/apache/spark/CometPlugin.class$" allowed_expr+="|^org/apache/spark/CometDriverPlugin.*$" +allowed_expr+="|^org/apache/spark/CometTaskMemoryManager.class$" +allowed_expr+="|^org/apache/spark/CometTaskMemoryManager.*$" allowed_expr+=")" declare -i bad_artifacts=0 diff --git a/spark/src/main/java/org/apache/spark/CometTaskMemoryManager.java b/spark/src/main/java/org/apache/spark/CometTaskMemoryManager.java new file mode 100644 index 000000000..96fa3b432 --- /dev/null +++ b/spark/src/main/java/org/apache/spark/CometTaskMemoryManager.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark; + +import java.io.IOException; + +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.memory.TaskMemoryManager; + +/** + * A adapter class that is used by Comet native to acquire & release memory through Spark's unified + * memory manager. This assumes Spark's off-heap memory mode is enabled. + */ +public class CometTaskMemoryManager { + /** The id uniquely identifies the native plan this memory manager is associated to */ + private final long id; + + private final TaskMemoryManager internal; + private final NativeMemoryConsumer nativeMemoryConsumer; + + public CometTaskMemoryManager(long id) { + this.id = id; + this.internal = TaskContext$.MODULE$.get().taskMemoryManager(); + this.nativeMemoryConsumer = new NativeMemoryConsumer(); + } + + // Called by Comet native through JNI. + // Returns the actual amount of memory (in bytes) granted. + public long acquireMemory(long size) { + return internal.acquireExecutionMemory(size, nativeMemoryConsumer); + } + + // Called by Comet native through JNI + public void releaseMemory(long size) { + internal.releaseExecutionMemory(size, nativeMemoryConsumer); + } + + /** + * A dummy memory consumer that does nothing when spilling. At the moment, Comet native doesn't + * share the same API as Spark and cannot trigger spill when acquire memory. Therefore, when + * acquiring memory from native or JVM, spilling can only be triggered from JVM operators. + */ + private class NativeMemoryConsumer extends MemoryConsumer { + protected NativeMemoryConsumer() { + super(CometTaskMemoryManager.this.internal, 0, MemoryMode.OFF_HEAP); + } + + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + // No spilling + return 0; + } + + @Override + public String toString() { + return String.format("NativeMemoryConsumer(id=%)", id); + } + } +} diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 20b2d384a..b3604c9e0 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -54,7 +54,13 @@ class CometExecIterator( }.toArray private val plan = { val configs = createNativeConf - nativeLib.createPlan(id, configs, cometBatchIterators, protobufQueryPlan, nativeMetrics) + nativeLib.createPlan( + id, + configs, + cometBatchIterators, + protobufQueryPlan, + nativeMetrics, + new CometTaskMemoryManager(id)) } private var nextBatch: Option[ColumnarBatch] = None @@ -83,6 +89,12 @@ class CometExecIterator( val conf = SparkEnv.get.conf val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(conf) + // Only enable unified memory manager when off-heap mode is enabled. Otherwise, + // we'll use the built-in memory pool from DF, and initializes with `memory_limit` + // and `memory_fraction` below. + result.put( + "use_unified_memory_manager", + String.valueOf(conf.get("spark.memory.offHeap.enabled", "false"))) result.put("memory_limit", String.valueOf(maxMemory)) result.put("memory_fraction", String.valueOf(COMET_EXEC_MEMORY_FRACTION.get())) result.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get())) diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index 1564fc991..97ded91b2 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -21,6 +21,7 @@ package org.apache.comet import java.util.Map +import org.apache.spark.CometTaskMemoryManager import org.apache.spark.sql.comet.CometMetricNode class Native extends NativeBase { @@ -38,6 +39,9 @@ class Native extends NativeBase { * the bytes of serialized SparkPlan. * @param metrics * the native metrics of SparkPlan. + * @param taskMemoryManager + * the task-level memory manager that is responsible for tracking memory usage across JVM and + * native side. * @return * the address to native query plan. */ @@ -46,7 +50,8 @@ class Native extends NativeBase { configMap: Map[String, String], iterators: Array[CometBatchIterator], plan: Array[Byte], - metrics: CometMetricNode): Long + metrics: CometMetricNode, + taskMemoryManager: CometTaskMemoryManager): Long /** * Execute a native query plan based on given input Arrow arrays. diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index fec6197d6..a9b29e6b7 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -53,7 +53,8 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar CometConf.COMET_EXEC_SHUFFLE_SPILL_THRESHOLD.key -> numElementsForceSpillThreshold.toString, CometConf.COMET_EXEC_ENABLED.key -> "false", CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", - CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_MEMORY_SIZE.key -> "1536m") { testFun } } diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala index 65ea8ba64..265235ffe 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.apache.spark.SparkConf +import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE} import org.apache.comet.CometConf @@ -150,10 +151,11 @@ class CometTPCDSQuerySuite "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") conf.set(CometConf.COMET_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ENABLED.key, "true") - conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g") conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true") + conf.set(MEMORY_OFFHEAP_ENABLED.key, "true") + conf.set(MEMORY_OFFHEAP_SIZE.key, "2g") conf } diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala index 9ea0218c2..954269a8a 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala @@ -25,6 +25,7 @@ import java.nio.file.{Files, Paths} import scala.collection.JavaConverters._ import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.util.{fileToString, resourceToString, stringToFile} import org.apache.spark.sql.internal.SQLConf @@ -87,10 +88,11 @@ class CometTPCHQuerySuite extends QueryTest with CometTPCBase with SQLQueryTestH "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") conf.set(CometConf.COMET_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ENABLED.key, "true") - conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g") conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true") + conf.set(MEMORY_OFFHEAP_ENABLED.key, "true") + conf.set(MEMORY_OFFHEAP_SIZE.key, "2g") } protected override def createSparkSession: TestSparkSession = { diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 38a8d7d2f..e31be8c28 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -23,9 +23,7 @@ import scala.concurrent.duration._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.scalactic.source.Position import org.scalatest.BeforeAndAfterEach -import org.scalatest.Tag import org.apache.hadoop.fs.Path import org.apache.parquet.column.ParquetProperties @@ -35,6 +33,7 @@ import org.apache.parquet.hadoop.ParquetWriter import org.apache.parquet.hadoop.example.ExampleParquetWriter import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark._ +import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE, SHUFFLE_MANAGER} import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, SparkPlan, WholeStageCodegenExec} @@ -65,28 +64,20 @@ abstract class CometTestBase val conf = new SparkConf() conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) conf.set(SQLConf.SHUFFLE_PARTITIONS, 10) // reduce parallelism in tests - conf.set("spark.shuffle.manager", shuffleManager) + conf.set(SQLConf.ANSI_ENABLED.key, "false") + conf.set(SHUFFLE_MANAGER, shuffleManager) + conf.set(MEMORY_OFFHEAP_ENABLED.key, "true") + conf.set(MEMORY_OFFHEAP_SIZE.key, "2g") + conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "1g") + conf.set(SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key, "1g") + conf.set(CometConf.COMET_ENABLED.key, "true") + conf.set(CometConf.COMET_EXEC_ENABLED.key, "true") + conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true") + conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true") conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g") conf } - override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit - pos: Position): Unit = { - super.test(testName, testTags: _*) { - withSQLConf( - CometConf.COMET_ENABLED.key -> "true", - CometConf.COMET_EXEC_ENABLED.key -> "true", - CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", - CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key -> "true", - CometConf.COMET_COLUMNAR_SHUFFLE_MEMORY_SIZE.key -> "2g", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1g", - SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "1g", - SQLConf.ANSI_ENABLED.key -> "false") { - testFun - } - } - } - /** * A helper function for comparing Comet DataFrame with Spark result using absolute tolerance. */