Commit 2671e0c

Stop passing Java config map into native createPlan (apache#1101)
andygrove authored Dec 4, 2024
1 parent 36a2307 commit 2671e0c
Showing 3 changed files with 28 additions and 83 deletions.
native/core/src/execution/jni_api.rs: 16 additions & 53 deletions
@@ -31,8 +31,8 @@ use futures::poll;
 use jni::{
     errors::Result as JNIResult,
     objects::{
-        JByteArray, JClass, JIntArray, JLongArray, JMap, JObject, JObjectArray, JPrimitiveArray,
-        JString, ReleaseMode,
+        JByteArray, JClass, JIntArray, JLongArray, JObject, JObjectArray, JPrimitiveArray, JString,
+        ReleaseMode,
     },
     sys::{jbyteArray, jint, jlong, jlongArray},
     JNIEnv,
@@ -77,8 +77,6 @@ struct ExecutionContext {
     pub input_sources: Vec<Arc<GlobalRef>>,
     /// The record batch stream to pull results from
     pub stream: Option<SendableRecordBatchStream>,
-    /// Configurations for DF execution
-    pub conf: HashMap<String, String>,
     /// The Tokio runtime used for async.
     pub runtime: Runtime,
     /// Native metrics
@@ -103,11 +101,15 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
     e: JNIEnv,
     _class: JClass,
     id: jlong,
-    config_object: JObject,
     iterators: jobjectArray,
     serialized_query: jbyteArray,
     metrics_node: JObject,
     comet_task_memory_manager_obj: JObject,
+    batch_size: jint,
+    debug_native: jboolean,
+    explain_native: jboolean,
+    worker_threads: jint,
+    blocking_threads: jint,
 ) -> jlong {
     try_unwrap_or_throw(&e, |mut env| {
         // Init JVM classes
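
The five new scalar parameters use JNI's fixed-width primitive types: jint is an i32 and jboolean is a u8, which is why later hunks in this file cast with `as usize` and compare against 1 (JNI_TRUE). A minimal sketch of those conversions, using the `jni` crate's re-exported types (the helper names are illustrative, not part of this patch):

    use jni::sys::{jboolean, jint, JNI_TRUE};

    // Illustrative helpers, not part of the patch.
    fn jboolean_to_bool(flag: jboolean) -> bool {
        flag == JNI_TRUE // jboolean is a u8; JNI_TRUE == 1
    }

    fn jint_to_usize(n: jint) -> usize {
        n as usize // jint is an i32; assumes the JVM side passes a non-negative value
    }
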
@@ -121,36 +123,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
         // Deserialize query plan
         let spark_plan = serde::deserialize_op(bytes.as_slice())?;

-        // Sets up context
-        let mut configs = HashMap::new();
-
-        let config_map = JMap::from_env(&mut env, &config_object)?;
-        let mut map_iter = config_map.iter(&mut env)?;
-        while let Some((key, value)) = map_iter.next(&mut env)? {
-            let key: String = env.get_string(&JString::from(key)).unwrap().into();
-            let value: String = env.get_string(&JString::from(value)).unwrap().into();
-            configs.insert(key, value);
-        }
-
-        // Whether we've enabled additional debugging on the native side
-        let debug_native = parse_bool(&configs, "debug_native")?;
-        let explain_native = parse_bool(&configs, "explain_native")?;
-
-        let worker_threads = configs
-            .get("worker_threads")
-            .map(String::as_str)
-            .unwrap_or("4")
-            .parse::<usize>()?;
-        let blocking_threads = configs
-            .get("blocking_threads")
-            .map(String::as_str)
-            .unwrap_or("10")
-            .parse::<usize>()?;
-
         // Use multi-threaded tokio runtime to prevent blocking spawned tasks if any
         let runtime = tokio::runtime::Builder::new_multi_thread()
-            .worker_threads(worker_threads)
-            .max_blocking_threads(blocking_threads)
+            .worker_threads(worker_threads as usize)
+            .max_blocking_threads(blocking_threads as usize)
             .enable_all()
             .build()?;

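The two thread counts configure different Tokio pools: worker_threads sizes the async executor that drives the plan, while max_blocking_threads caps the extra pool used by spawn_blocking. A standalone sketch of the same builder call, assuming the tokio crate with the rt-multi-thread feature enabled:

    use tokio::runtime::{Builder, Runtime};

    // Mirrors the builder call above, with the JNI integers already cast to usize.
    fn build_runtime(worker_threads: usize, blocking_threads: usize) -> std::io::Result<Runtime> {
        Builder::new_multi_thread()
            .worker_threads(worker_threads)         // threads driving async tasks
            .max_blocking_threads(blocking_threads) // cap on the spawn_blocking pool
            .enable_all()                           // enable the IO and time drivers
            .build()
    }
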
@@ -171,7 +147,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
         // We need to keep the session context alive. Some session state like temporary
         // dictionaries are stored in session context. If it is dropped, the temporary
         // dictionaries will be dropped as well.
-        let session = prepare_datafusion_session_context(&configs, task_memory_manager)?;
+        let session = prepare_datafusion_session_context(batch_size as usize, task_memory_manager)?;

         let plan_creation_time = start.elapsed();

@@ -182,33 +158,24 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
             scans: vec![],
             input_sources,
             stream: None,
-            conf: configs,
             runtime,
             metrics,
             plan_creation_time,
             session_ctx: Arc::new(session),
-            debug_native,
-            explain_native,
+            debug_native: debug_native == 1,
+            explain_native: explain_native == 1,
             metrics_jstrings: HashMap::new(),
         });

         Ok(Box::into_raw(exec_context) as i64)
     })
 }

-/// Parse Comet configs and configure DataFusion session context.
+/// Configure DataFusion session context.
 fn prepare_datafusion_session_context(
-    conf: &HashMap<String, String>,
+    batch_size: usize,
     comet_task_memory_manager: Arc<GlobalRef>,
 ) -> CometResult<SessionContext> {
-    // Get the batch size from Comet JVM side
-    let batch_size = conf
-        .get("batch_size")
-        .ok_or(CometError::Internal(
-            "Config 'batch_size' is not specified from Comet JVM side".to_string(),
-        ))?
-        .parse::<usize>()?;
-
     let mut rt_config = RuntimeConfig::new().with_disk_manager(DiskManagerConfig::NewOs);

     // Set Comet memory pool for native
@@ -218,7 +185,7 @@ fn prepare_datafusion_session_context(
     // Get Datafusion configuration from Spark Execution context
     // can be configured in Comet Spark JVM using Spark --conf parameters
     // e.g: spark-shell --conf spark.datafusion.sql_parser.parse_float_as_decimal=true
-    let mut session_config = SessionConfig::new()
+    let session_config = SessionConfig::new()
         .with_batch_size(batch_size)
         // DataFusion partial aggregates can emit duplicate rows so we disable the
         // skip partial aggregation feature because this is not compatible with Spark's
@@ -231,11 +198,7 @@ fn prepare_datafusion_session_context(
             &ScalarValue::Float64(Some(1.1)),
         );

-    for (key, value) in conf.iter().filter(|(k, _)| k.starts_with("datafusion.")) {
-        session_config = session_config.set_str(key, value);
-    }
-
-    let runtime = RuntimeEnv::try_new(rt_config).unwrap();
+    let runtime = RuntimeEnv::try_new(rt_config)?;

     let mut session_ctx = SessionContext::new_with_config_rt(session_config, Arc::new(runtime));

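With the config map gone, native session setup reduces to the batch size, the Comet memory-pool hook, and the fixed aggregate settings shown above. A hedged sketch of the DataFusion calls involved, minus the memory pool and disk-manager wiring (module paths match DataFusion releases from around the time of this commit and may have moved since):

    use std::sync::Arc;

    use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
    use datafusion::prelude::{SessionConfig, SessionContext};

    // Build a SessionContext whose record batches are capped at `batch_size` rows.
    fn make_session(batch_size: usize) -> datafusion::error::Result<SessionContext> {
        let session_config = SessionConfig::new().with_batch_size(batch_size);
        let runtime = RuntimeEnv::try_new(RuntimeConfig::new())?;
        Ok(SessionContext::new_with_config_rt(session_config, Arc::new(runtime)))
    }
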
spark/src/main/scala/org/apache/comet/CometExecIterator.scala: 6 additions & 26 deletions
@@ -60,43 +60,23 @@ class CometExecIterator(
     new CometBatchIterator(iterator, nativeUtil)
   }.toArray
   private val plan = {
-    val configs = createNativeConf
     nativeLib.createPlan(
       id,
-      configs,
       cometBatchIterators,
       protobufQueryPlan,
       nativeMetrics,
-      new CometTaskMemoryManager(id))
+      new CometTaskMemoryManager(id),
+      batchSize = COMET_BATCH_SIZE.get(),
+      debug = COMET_DEBUG_ENABLED.get(),
+      explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
+      workerThreads = COMET_WORKER_THREADS.get(),
+      blockingThreads = COMET_BLOCKING_THREADS.get())
   }

   private var nextBatch: Option[ColumnarBatch] = None
   private var currentBatch: ColumnarBatch = null
   private var closed: Boolean = false

-  /**
-   * Creates a new configuration map to be passed to the native side.
-   */
-  private def createNativeConf: java.util.HashMap[String, String] = {
-    val result = new java.util.HashMap[String, String]()
-    val conf = SparkEnv.get.conf
-
-    result.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
-    result.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get()))
-    result.put("explain_native", String.valueOf(COMET_EXPLAIN_NATIVE_ENABLED.get()))
-    result.put("worker_threads", String.valueOf(COMET_WORKER_THREADS.get()))
-    result.put("blocking_threads", String.valueOf(COMET_BLOCKING_THREADS.get()))
-
-    // Strip mandatory prefix spark. which is not required for DataFusion session params
-    conf.getAll.foreach {
-      case (k, v) if k.startsWith("spark.datafusion") =>
-        result.put(k.replaceFirst("spark\\.", ""), v)
-      case _ =>
-    }
-
-    result
-  }
-
   def getNextBatch(): Option[ColumnarBatch] = {
     assert(partitionIndex >= 0 && partitionIndex < numParts)

spark/src/main/scala/org/apache/comet/Native.scala: 6 additions & 4 deletions
@@ -19,8 +19,6 @@

 package org.apache.comet

-import java.util.Map
-
 import org.apache.spark.CometTaskMemoryManager
 import org.apache.spark.sql.comet.CometMetricNode

@@ -47,11 +45,15 @@ class Native extends NativeBase {
    */
   @native def createPlan(
       id: Long,
-      configMap: Map[String, String],
       iterators: Array[CometBatchIterator],
       plan: Array[Byte],
       metrics: CometMetricNode,
-      taskMemoryManager: CometTaskMemoryManager): Long
+      taskMemoryManager: CometTaskMemoryManager,
+      batchSize: Int,
+      debug: Boolean,
+      explain: Boolean,
+      workerThreads: Int,
+      blockingThreads: Int): Long

   /**
    * Execute a native query plan based on given input Arrow arrays.
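Nothing registers these natives explicitly: the JVM resolves an @native method to a C symbol derived from its fully qualified class name (Java_<package>_<Class>_<method>, dots replaced by underscores), which is how createPlan here binds to Java_org_apache_comet_Native_createPlan in jni_api.rs. A hypothetical minimal pairing, with a made-up method name for illustration:

    // A Scala declaration such as `@native def ping(x: Int): Long` on
    // org.apache.comet.Native would resolve to this Rust export.
    use jni::objects::JClass;
    use jni::sys::{jint, jlong};
    use jni::JNIEnv;

    #[no_mangle]
    pub extern "system" fn Java_org_apache_comet_Native_ping(
        _env: JNIEnv,
        _class: JClass,
        x: jint,
    ) -> jlong {
        x as jlong // echo the input, just enough for the sketch to compile
    }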
