From 37f17a5a5a0202c1e7ce95d9875bdb6e253b5df8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 20 Feb 2024 14:08:50 -0800 Subject: [PATCH] feat: Pull based native execution --- .../org/apache/comet/vector/NativeUtil.scala | 16 +- core/src/errors.rs | 9 + core/src/execution/datafusion/planner.rs | 83 +++++--- core/src/execution/jni_api.rs | 183 +++++------------- core/src/execution/operators/scan.rs | 164 ++++++++++++++-- core/src/jvm_bridge/batch_iterator.rs | 46 +++++ core/src/jvm_bridge/mod.rs | 15 +- .../org/apache/comet/CometBatchIterator.java | 56 ++++++ .../org/apache/comet/CometExecIterator.scala | 93 ++------- .../main/scala/org/apache/comet/Native.scala | 53 +---- .../comet/parquet/ParquetReadSuite.scala | 3 +- 11 files changed, 405 insertions(+), 316 deletions(-) create mode 100644 core/src/jvm_bridge/batch_iterator.rs create mode 100644 spark/src/main/java/org/apache/comet/CometBatchIterator.java diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index 1e27ed8f0..4bb63e501 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -39,11 +39,14 @@ class NativeUtil { * @param batch * the input Comet columnar batch * @return - * a list containing pairs of memory addresses in the format of (address of Arrow array, - * address of Arrow schema) + * a list containing number of rows + pairs of memory addresses in the format of (address of + * Arrow array, address of Arrow schema) */ def exportBatch(batch: ColumnarBatch): Array[Long] = { - val vectors = (0 until batch.numCols()).flatMap { index => + val exportedVectors = mutable.ArrayBuffer.empty[Long] + exportedVectors += batch.numRows() + + (0 until batch.numCols()).foreach { index => batch.column(index) match { case a: CometVector => val valueVector = a.getValueVector @@ -63,7 +66,8 @@ class NativeUtil { arrowArray, arrowSchema) - Seq((arrowArray, arrowSchema)) + exportedVectors += arrowArray.memoryAddress() + exportedVectors += arrowSchema.memoryAddress() case c => throw new SparkException( "Comet execution only takes Arrow Arrays, but got " + @@ -71,9 +75,7 @@ class NativeUtil { } } - vectors.flatMap { pair => - Seq(pair._1.memoryAddress(), pair._2.memoryAddress()) - }.toArray + exportedVectors.toArray } /** diff --git a/core/src/errors.rs b/core/src/errors.rs index 0da2c9c74..16ed7c312 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -159,6 +159,15 @@ impl From for DataFusionError { } } +impl From for ExecutionError { + fn from(value: CometError) -> Self { + match value { + CometError::Execution { source } => source, + _ => ExecutionError::GeneralError(value.to_string()), + } + } +} + impl jni::errors::ToException for CometError { fn to_exception(&self) -> Exception { match self { diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index c13272453..2feaacebf 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -47,6 +47,7 @@ use datafusion_physical_expr::{ AggregateExpr, ScalarFunctionExpr, }; use itertools::Itertools; +use jni::objects::GlobalRef; use num::{BigInt, ToPrimitive}; use crate::{ @@ -70,7 +71,7 @@ use crate::{ operators::expand::CometExpandExec, shuffle_writer::ShuffleWriterExec, }, - operators::{CopyExec, ExecutionError, InputBatch, ScanExec}, + operators::{CopyExec, ExecutionError, ScanExec}, serde::to_arrow_datatype, 
spark_expression,
    spark_expression::{
@@ -88,6 +89,8 @@ type PhyAggResult = Result<Vec<Arc<dyn AggregateExpr>>, ExecutionError>;
 type PhyExprResult = Result<Vec<(Arc<dyn PhysicalExpr>, String)>, ExecutionError>;
 type PartitionPhyExprResult = Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError>;
 
+pub const TEST_EXEC_CONTEXT_ID: i64 = -1;
+
 /// The query planner for converting Spark query plans to DataFusion query plans.
 pub struct PhysicalPlanner {
     // The execution context id of this planner.
@@ -105,7 +108,7 @@ impl PhysicalPlanner {
     pub fn new() -> Self {
         let execution_props = ExecutionProps::new();
         Self {
-            exec_context_id: -1,
+            exec_context_id: TEST_EXEC_CONTEXT_ID,
             execution_props,
         }
     }
@@ -612,24 +615,28 @@ impl PhysicalPlanner {
 
     /// Create a DataFusion physical plan from Spark physical plan.
     ///
-    /// Note that we need `input_batches` parameter because we need to know the exact schema (not
-    /// only data type but also dictionary-encoding) at `ScanExec`s. It is because some DataFusion
-    /// operators, e.g., `ProjectionExec`, gets child operator schema during initialization and
-    /// uses it later for `RecordBatch`. We may be able to get rid of it once `RecordBatch`
-    /// relaxes schema check.
+    /// `inputs` is a vector of input sources, i.e. global references to JVM `CometBatchIterator`
+    /// objects. It is used to create `ScanExec`s: each `ScanExec` consumes one entry from
+    /// `inputs` and uses it to pull its input batches through the JNI API.
+    ///
+    /// Note that `ScanExec` pulls its initial input batch during initialization, because we need
+    /// to know the exact schema (not only the data type but also the dictionary encoding) at
+    /// `ScanExec`s. Some DataFusion operators, e.g., `ProjectionExec`, get the child operator
+    /// schema during initialization and use it later for `RecordBatch`. We may be able to get
+    /// rid of this once `RecordBatch` relaxes its schema check.
     ///
     /// Note that we return created `Scan`s which will be kept at JNI API. JNI calls will use it to
     /// feed in new input batch from Spark JVM side. 
pub fn create_plan<'a>( &'a self, spark_plan: &'a Operator, - input_batches: &mut Vec, + inputs: &mut Vec>, ) -> Result<(Vec, Arc), ExecutionError> { let children = &spark_plan.children; match spark_plan.op_struct.as_ref().unwrap() { OpStruct::Projection(project) => { assert!(children.len() == 1); - let (scans, child) = self.create_plan(&children[0], input_batches)?; + let (scans, child) = self.create_plan(&children[0], inputs)?; let exprs: PhyExprResult = project .project_list .iter() @@ -643,7 +650,7 @@ impl PhysicalPlanner { } OpStruct::Filter(filter) => { assert!(children.len() == 1); - let (scans, child) = self.create_plan(&children[0], input_batches)?; + let (scans, child) = self.create_plan(&children[0], inputs)?; let predicate = self.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?; @@ -651,7 +658,7 @@ impl PhysicalPlanner { } OpStruct::HashAgg(agg) => { assert!(children.len() == 1); - let (scans, child) = self.create_plan(&children[0], input_batches)?; + let (scans, child) = self.create_plan(&children[0], inputs)?; let group_exprs: PhyExprResult = agg .grouping_exprs @@ -716,13 +723,13 @@ impl PhysicalPlanner { OpStruct::Limit(limit) => { assert!(children.len() == 1); let num = limit.limit; - let (scans, child) = self.create_plan(&children[0], input_batches)?; + let (scans, child) = self.create_plan(&children[0], inputs)?; Ok((scans, Arc::new(LocalLimitExec::new(child, num as usize)))) } OpStruct::Sort(sort) => { assert!(children.len() == 1); - let (scans, child) = self.create_plan(&children[0], input_batches)?; + let (scans, child) = self.create_plan(&children[0], inputs)?; let exprs: Result, ExecutionError> = sort .sort_orders @@ -741,21 +748,32 @@ impl PhysicalPlanner { } OpStruct::Scan(scan) => { let fields = scan.fields.iter().map(to_arrow_datatype).collect_vec(); - if input_batches.is_empty() { + + // If it is not test execution context for unit test, we should have at least one + // input source + if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() { return Err(ExecutionError::GeneralError( - "No input batch for scan".to_string(), + "No input for scan".to_string(), )); } - // Consumes the first input batch source for the scan - let input_batch = input_batches.remove(0); + + // Consumes the first input source for the scan + let input_source = if self.exec_context_id == TEST_EXEC_CONTEXT_ID + && inputs.is_empty() + { + // For unit test, we will set input batch to scan directly by `set_input_batch`. 
+ None + } else { + Some(inputs.remove(0)) + }; // The `ScanExec` operator will take actual arrays from Spark during execution - let scan = ScanExec::new(input_batch, fields); + let scan = ScanExec::new(self.exec_context_id, input_source, fields)?; Ok((vec![scan.clone()], Arc::new(scan))) } OpStruct::ShuffleWriter(writer) => { assert!(children.len() == 1); - let (scans, child) = self.create_plan(&children[0], input_batches)?; + let (scans, child) = self.create_plan(&children[0], inputs)?; let partitioning = self .create_partitioning(writer.partitioning.as_ref().unwrap(), child.schema())?; @@ -772,7 +790,7 @@ impl PhysicalPlanner { } OpStruct::Expand(expand) => { assert!(children.len() == 1); - let (scans, child) = self.create_plan(&children[0], input_batches)?; + let (scans, child) = self.create_plan(&children[0], inputs)?; let mut projections = vec![]; let mut projection = vec![]; @@ -805,6 +823,18 @@ impl PhysicalPlanner { .collect(); let schema = Arc::new(Schema::new(fields)); + // `Expand` operator keeps the input batch and expands it to multiple output + // batches. However, `ScanExec` will reuse input arrays for the next + // input batch. Therefore, we need to copy the input batch to avoid + // the data corruption. Note that we only need to copy the input batch + // if the child operator is `ScanExec`, because other operators after `ScanExec` + // will create new arrays for the output batch. + let child = if child.as_any().downcast_ref::().is_some() { + Arc::new(CopyExec::new(child)) + } else { + child + }; + Ok(( scans, Arc::new(CometExpandExec::new(projections, child, schema)), @@ -997,9 +1027,9 @@ mod tests { let values = Int32Array::from(vec![0, 1, 2, 3]); let input_array = DictionaryArray::new(keys, Arc::new(values)); let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count); - let mut input_batches = vec![input_batch]; - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut input_batches).unwrap(); + let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![]).unwrap(); + scans[0].set_input_batch(input_batch); let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); @@ -1077,9 +1107,11 @@ mod tests { let values = StringArray::from(vec!["foo", "bar", "hello", "comet"]); let input_array = DictionaryArray::new(keys, Arc::new(values)); let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count); - let mut input_batches = vec![input_batch]; - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut input_batches).unwrap(); + let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![]).unwrap(); + + // Scan's schema is determined by the input batch, so we need to set it before execution. + scans[0].set_input_batch(input_batch); let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); @@ -1147,8 +1179,7 @@ mod tests { let op = create_filter(op_scan, 0); let planner = PhysicalPlanner::new(); - let mut input_batches = vec![InputBatch::EOF]; - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut input_batches).unwrap(); + let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![]).unwrap(); let scan = &mut scans[0]; scan.set_input_batch(InputBatch::EOF); diff --git a/core/src/execution/jni_api.rs b/core/src/execution/jni_api.rs index 831f78838..fae213bca 100644 --- a/core/src/execution/jni_api.rs +++ b/core/src/execution/jni_api.rs @@ -17,9 +17,7 @@ //! Define JNI APIs which can be called from Java/Scala. 
-use crate::execution::operators::{InputBatch, ScanExec}; use arrow::{ - array::{make_array, Array, ArrayData, ArrayRef}, datatypes::DataType as ArrowDataType, ffi::{FFI_ArrowArray, FFI_ArrowSchema}, }; @@ -32,13 +30,12 @@ use datafusion::{ physical_plan::{display::DisplayableExecutionPlan, ExecutionPlan, SendableRecordBatchStream}, prelude::{SessionConfig, SessionContext}, }; -use datafusion_common::DataFusionError; use futures::poll; use jni::{ errors::Result as JNIResult, objects::{ - AutoElements, JBooleanArray, JByteArray, JClass, JIntArray, JLongArray, JMap, JObject, - JObjectArray, JPrimitiveArray, JString, ReleaseMode, + JByteArray, JClass, JIntArray, JLongArray, JMap, JObject, JObjectArray, JPrimitiveArray, + JString, ReleaseMode, }, sys::{jbyteArray, jint, jlong, jlongArray}, JNIEnv, @@ -59,10 +56,11 @@ use crate::{ use futures::stream::StreamExt; use jni::{ objects::GlobalRef, - sys::{jboolean, jbooleanArray, jdouble, jintArray, jobjectArray, jstring}, + sys::{jboolean, jdouble, jintArray, jobjectArray, jstring}, }; use tokio::runtime::Runtime; +use crate::execution::operators::ScanExec; use log::info; /// Comet native execution context. Kept alive across JNI calls. @@ -75,6 +73,8 @@ struct ExecutionContext { pub root_op: Option>, /// The input sources for the DataFusion plan pub scans: Vec, + /// The global reference of input sources for the DataFusion plan + pub input_sources: Vec>, /// The record batch stream to pull results from pub stream: Option, /// The FFI arrays. We need to keep them alive here. @@ -100,6 +100,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( _class: JClass, id: jlong, config_object: JObject, + iterators: jobjectArray, serialized_query: jbyteArray, metrics_node: JObject, ) -> jlong { @@ -137,6 +138,16 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( let metrics = Arc::new(jni_new_global_ref!(env, metrics_node)?); + // Get the global references of input sources + let mut input_sources = vec![]; + let iter_array = JObjectArray::from_raw(iterators); + let num_inputs = env.get_array_length(&iter_array)?; + for i in 0..num_inputs { + let input_source = env.get_object_array_element(&iter_array, i)?; + let input_source = Arc::new(jni_new_global_ref!(env, input_source)?); + input_sources.push(input_source); + } + // We need to keep the session context alive. Some session state like temporary // dictionaries are stored in session context. If it is dropped, the temporary // dictionaries will be dropped as well. @@ -147,6 +158,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( spark_plan, root_op: None, scans: vec![], + input_sources, stream: None, ffi_arrays: vec![], conf: configs, @@ -164,7 +176,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( fn prepare_datafusion_session_context( conf: &HashMap, ) -> CometResult { - // Get the batch size from Boson JVM side + // Get the batch size from Comet JVM side let batch_size = conf .get("batch_size") .ok_or(CometError::Internal( @@ -212,10 +224,9 @@ fn prepare_datafusion_session_context( /// Prepares arrow arrays for output. fn prepare_output( env: &mut JNIEnv, - output: Result, + output_batch: RecordBatch, exec_context: &mut ExecutionContext, ) -> CometResult { - let output_batch = output?; let results = output_batch.columns(); let num_rows = output_batch.num_rows(); @@ -260,6 +271,20 @@ fn prepare_output( Ok(long_array.into_raw()) } +/// Pull the next input from JVM. 
Note that we cannot pull input batches in +/// `ScanStream.poll_next` when the execution stream is polled for output. +/// Because the input source could be another native execution stream, which +/// will be executed in another tokio blocking thread. It causes JNI throw +/// Java exception. So we pull input batches here and insert them into scan +/// operators before polling the stream, +#[inline] +fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometError> { + exec_context.scans.iter_mut().try_for_each(|scan| { + scan.get_next_batch()?; + Ok::<(), CometError>(()) + }) +} + /// Accept serialized query plan and the addresses of Arrow Arrays from Spark, /// then execute the query. Return addresses of arrow vector. /// # Safety @@ -269,76 +294,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( e: JNIEnv, _class: JClass, exec_context: jlong, - addresses_array: jobjectArray, - finishes: jbooleanArray, - batch_rows: jint, ) -> jlongArray { - try_unwrap_or_throw(&e, |mut env| unsafe { + try_unwrap_or_throw(&e, |mut env| { + // Retrieve the query let exec_context = get_execution_context(exec_context); - let addresses = JObjectArray::from_raw(addresses_array); - let num_addresses = env.get_array_length(&addresses)? as usize; - - let mut all_inputs: Vec> = Vec::with_capacity(num_addresses); - - for i in 0..num_addresses { - let mut inputs: Vec = vec![]; - - let inner_addresses = env.get_object_array_element(&addresses, i as i32)?.into(); - let inner_address_array: AutoElements = - env.get_array_elements(&inner_addresses, ReleaseMode::NoCopyBack)?; - - let num_inner_address = inner_address_array.len(); - assert_eq!( - num_inner_address % 2, - 0, - "Arrow Array addresses are invalid!" - ); - - let num_arrays = num_inner_address / 2; - let array_elements = inner_address_array.as_ptr(); - - let mut i: usize = 0; - while i < num_arrays { - let array_ptr = *(array_elements.add(i * 2)); - let schema_ptr = *(array_elements.add(i * 2 + 1)); - let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?; - - if exec_context.debug_native { - // Validate the array data from JVM. - array_data.validate_full().expect("Invalid array data"); - } - - inputs.push(make_array(array_data)); - i += 1; - } - - all_inputs.push(inputs); - } - - // Prepares the input batches. - let array = JBooleanArray::from_raw(finishes); - let eofs = env.get_array_elements(&array, ReleaseMode::NoCopyBack)?; - let eof_flags = eofs.as_ptr(); - - // Whether reaching the end of input batches. - let mut finished = true; - let mut input_batches = all_inputs - .into_iter() - .enumerate() - .map(|(idx, inputs)| { - let eof = eof_flags.add(idx); - - if *eof == 1 { - InputBatch::EOF - } else { - finished = false; - InputBatch::new(inputs, Some(batch_rows as usize)) - } - }) - .collect::>(); - - // Retrieve the query let exec_context_id = exec_context.id; // Initialize the execution stream. @@ -346,8 +306,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( // query plan, we need to defer stream initialization to first time execution. 
if exec_context.root_op.is_none() { let planner = PhysicalPlanner::new().with_exec_id(exec_context_id); - let (scans, root_op) = - planner.create_plan(&exec_context.spark_plan, &mut input_batches)?; + let (scans, root_op) = planner.create_plan( + &exec_context.spark_plan, + &mut exec_context.input_sources.clone(), + )?; exec_context.root_op = Some(root_op.clone()); exec_context.scans = scans; @@ -366,15 +328,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( .execute(0, task_ctx)?; exec_context.stream = Some(stream); } else { - input_batches - .into_iter() - .enumerate() - .for_each(|(idx, input_batch)| { - let scan = &mut exec_context.scans[idx]; - - // Set inputs at `Scan` node. - scan.set_input_batch(input_batch); - }); + // Pull input batches + pull_input_batches(exec_context)?; } loop { @@ -384,7 +339,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( match poll_output { Poll::Ready(Some(output)) => { - return prepare_output(&mut env, output, exec_context); + return prepare_output(&mut env, output?, exec_context); } Poll::Ready(None) => { // Reaches EOF of output. @@ -397,23 +352,18 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( return Ok(long_array.into_raw()); } - // After reaching the end of any input, a poll pending means there are more than - // one blocking operators, we don't need go back-forth - // between JVM/Native. Just keeping polling. - Poll::Pending if finished => { + // A poll pending means there are more than one blocking operators, + // we don't need go back-forth between JVM/Native. Just keeping polling. + Poll::Pending => { // Update metrics update_metrics(&mut env, exec_context)?; + // Pull input batches + pull_input_batches(exec_context)?; + // Output not ready yet continue; } - // Not reaching the end of input yet, so a poll pending means there are blocking - // operators. Just returning to keep reading next input. - Poll::Pending => { - // Update metrics - update_metrics(&mut env, exec_context)?; - return return_pending(env); - } } } }) @@ -425,37 +375,6 @@ fn return_pending(env: JNIEnv) -> Result { Ok(long_array.into_raw()) } -#[no_mangle] -/// Peeks into next output if any. -pub extern "system" fn Java_org_apache_comet_Native_peekNext( - e: JNIEnv, - _class: JClass, - exec_context: jlong, -) -> jlongArray { - try_unwrap_or_throw(&e, |mut env| { - // Retrieve the query - let exec_context = get_execution_context(exec_context); - - if exec_context.stream.is_none() { - // Plan is not initialized yet. - return return_pending(env); - } - - // Polling the stream. - let next_item = exec_context.stream.as_mut().unwrap().next(); - let poll_output = exec_context.runtime.block_on(async { poll!(next_item) }); - - match poll_output { - Poll::Ready(Some(output)) => prepare_output(&mut env, output, exec_context), - _ => { - // Update metrics - update_metrics(&mut env, exec_context)?; - return_pending(env) - } - } - }) -} - #[no_mangle] /// Drop the native query plan object and context object. pub extern "system" fn Java_org_apache_comet_Native_releasePlan( @@ -507,7 +426,7 @@ fn get_execution_context<'a>(id: i64) -> &'a mut ExecutionContext { } } -/// Used by Boson shuffle external sorter to write sorted records to disk. +/// Used by Comet shuffle external sorter to write sorted records to disk. /// # Safety /// This function is inheritly unsafe since it deals with raw pointers passed from JNI. 
#[no_mangle] @@ -577,7 +496,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative } #[no_mangle] -/// Used by Boson shuffle external sorter to sort in-memory row partition ids. +/// Used by Comet shuffle external sorter to sort in-memory row partition ids. pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative( e: JNIEnv, _class: JClass, diff --git a/core/src/execution/operators/scan.rs b/core/src/execution/operators/scan.rs index f80db6c56..9f85de80f 100644 --- a/core/src/execution/operators/scan.rs +++ b/core/src/execution/operators/scan.rs @@ -26,33 +26,62 @@ use futures::Stream; use itertools::Itertools; use arrow::compute::{cast_with_options, CastOptions}; -use arrow_array::{ArrayRef, RecordBatch, RecordBatchOptions}; +use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions}; +use arrow_data::ArrayData; use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use crate::{ + errors::CometError, + execution::{ + datafusion::planner::TEST_EXEC_CONTEXT_ID, operators::ExecutionError, + utils::SparkArrowConvert, + }, + jvm_bridge::{jni_call, JVMClasses}, +}; use datafusion::{ execution::TaskContext, physical_expr::*, physical_plan::{ExecutionPlan, *}, }; use datafusion_common::{DataFusionError, Result as DataFusionResult}; +use jni::{ + objects::{GlobalRef, JLongArray, JObject, ReleaseMode}, + sys::jlongArray, +}; #[derive(Debug, Clone)] pub struct ScanExec { - pub batch: Arc>>, + /// The ID of the execution context that owns this subquery. We use this ID to retrieve the JVM + /// environment `JNIEnv` from the execution context. + pub exec_context_id: i64, + /// The input source of scan node. It is a global reference of JVM `CometBatchIterator` object. + pub input_source: Option>, + /// The data types of columns of the input batch. Converted from Spark schema. pub data_types: Vec, + /// The input batch of input data. Used to determine the schema of the input data. + /// It is also used in unit test to mock the input data from JVM. + pub batch: Arc>>, } impl ScanExec { - pub fn new(batch: InputBatch, data_types: Vec) -> Self { - Self { - batch: Arc::new(Mutex::new(Some(batch))), - data_types, - } - } + pub fn new( + exec_context_id: i64, + input_source: Option>, + data_types: Vec, + ) -> Result { + // Scan's schema is determined by the input batch, so we need to set it before execution. + let first_batch = if let Some(input_source) = input_source.as_ref() { + ScanExec::get_next(exec_context_id, input_source.as_obj())? + } else { + InputBatch::EOF + }; - /// Feeds input batch into this `Scan`. - pub fn set_input_batch(&mut self, input: InputBatch) { - *self.batch.try_lock().unwrap() = Some(input); + Ok(Self { + exec_context_id, + input_source, + data_types, + batch: Arc::new(Mutex::new(Some(first_batch))), + }) } /// Checks if the input data type `dt` is a dictionary type with primitive value type. @@ -74,6 +103,98 @@ impl ScanExec { dt.clone() } + + /// Feeds input batch into this `Scan`. Only used in unit test. + pub fn set_input_batch(&mut self, input: InputBatch) { + *self.batch.try_lock().unwrap() = Some(input); + } + + /// Pull next input batch from JVM. + pub fn get_next_batch(&mut self) -> Result<(), CometError> { + let mut current_batch = self.batch.try_lock().unwrap(); + + if self.input_source.is_none() { + // This is a unit test. We don't need to call JNI. 
+ return Ok(()); + } + + if current_batch.is_none() { + let next_batch = ScanExec::get_next( + self.exec_context_id, + self.input_source.as_ref().unwrap().as_obj(), + )?; + *current_batch = Some(next_batch); + } + + Ok(()) + } + + /// Invokes JNI call to get next batch. + fn get_next(exec_context_id: i64, iter: &JObject) -> Result { + if exec_context_id == TEST_EXEC_CONTEXT_ID { + // This is a unit test. We don't need to call JNI. + return Ok(InputBatch::EOF); + } + + let mut env = JVMClasses::get_env(); + + if iter.is_null() { + return Err(CometError::from(ExecutionError::GeneralError(format!( + "Null batch iterator object. Plan id: {}", + exec_context_id + )))); + } + + let batch_object: JObject = unsafe { + jni_call!(&mut env, + comet_batch_iterator(iter).next() -> JObject)? + }; + + if batch_object.is_null() { + return Err(CometError::from(ExecutionError::GeneralError(format!( + "Null batch object. Plan id: {}", + exec_context_id + )))); + } + + let batch_object = unsafe { JLongArray::from_raw(batch_object.as_raw() as jlongArray) }; + + let addresses = unsafe { env.get_array_elements(&batch_object, ReleaseMode::NoCopyBack)? }; + + let mut inputs: Vec = vec![]; + + // First element is the number of rows. + let num_rows = unsafe { *addresses.as_ptr() as i64 }; + + if num_rows < 0 { + return Ok(InputBatch::EOF); + } + + let array_num = addresses.len() - 1; + if array_num % 2 != 0 { + return Err(CometError::Internal(format!( + "Invalid number of Arrow Array addresses: {}", + array_num + ))); + } + + let num_arrays = array_num / 2; + let array_elements = unsafe { addresses.as_ptr().add(1) }; + + let mut i: usize = 0; + while i < num_arrays { + let array_ptr = unsafe { *(array_elements.add(i * 2)) }; + let schema_ptr = unsafe { *(array_elements.add(i * 2 + 1)) }; + let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?; + + // TODO: validate array input data + + inputs.push(make_array(array_data)); + i += 1; + } + + Ok(InputBatch::new(inputs, Some(num_rows as usize))) + } } impl ExecutionPlan for ScanExec { @@ -214,19 +335,22 @@ impl Stream for ScanStream { fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { let mut scan_batch = self.scan.batch.try_lock().unwrap(); let input_batch = &*scan_batch; + + let input_batch = if let Some(batch) = input_batch { + batch + } else { + return Poll::Pending; + }; + let result = match input_batch { - // Input batch is not ready. - None => Poll::Pending, - Some(batch) => match batch { - InputBatch::EOF => Poll::Ready(None), - InputBatch::Batch(columns, num_rows) => { - Poll::Ready(Some(self.build_record_batch(columns, *num_rows))) - } - }, + InputBatch::EOF => Poll::Ready(None), + InputBatch::Batch(columns, num_rows) => { + Poll::Ready(Some(self.build_record_batch(columns, *num_rows))) + } }; - // Reset the current input batch so it won't be processed again *scan_batch = None; + result } } diff --git a/core/src/jvm_bridge/batch_iterator.rs b/core/src/jvm_bridge/batch_iterator.rs new file mode 100644 index 000000000..474e4fdcf --- /dev/null +++ b/core/src/jvm_bridge/batch_iterator.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::get_global_jclass; +use jni::{ + errors::Result as JniResult, + objects::{JClass, JMethodID}, + signature::ReturnType, + JNIEnv, +}; + +/// A struct that holds all the JNI methods and fields for JVM `CometBatchIterator` class. +pub struct CometBatchIterator<'a> { + pub class: JClass<'a>, + pub method_next: JMethodID, + pub method_next_ret: ReturnType, +} + +impl<'a> CometBatchIterator<'a> { + pub const JVM_CLASS: &'static str = "org/apache/comet/CometBatchIterator"; + + pub fn new(env: &mut JNIEnv<'a>) -> JniResult> { + // Get the global class reference + let class = get_global_jclass(env, Self::JVM_CLASS)?; + + Ok(CometBatchIterator { + class, + method_next: env.get_method_id(Self::JVM_CLASS, "next", "()[J").unwrap(), + method_next_ret: ReturnType::Array, + }) + } +} diff --git a/core/src/jvm_bridge/mod.rs b/core/src/jvm_bridge/mod.rs index 087096140..7a2882e30 100644 --- a/core/src/jvm_bridge/mod.rs +++ b/core/src/jvm_bridge/mod.rs @@ -17,6 +17,8 @@ //! JNI JVM related functions +use crate::errors::CometResult; + use jni::{ errors::{Error, Result as JniResult}, objects::{JClass, JMethodID, JObject, JString, JThrowable, JValueGen, JValueOwned}, @@ -71,7 +73,7 @@ macro_rules! jni_call { let ret = $env.call_method_unchecked($obj, method_id, ret_type, args); // Check if JVM has thrown any exception, and handle it if so. - let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env).unwrap() { + let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env)? { Err(exception.into()) } else { $crate::jvm_bridge::jni_map_error!($env, ret) @@ -190,11 +192,11 @@ pub fn get_global_jclass(env: &mut JNIEnv, cls: &str) -> JniResult { pub comet_metric_node: CometMetricNode<'a>, /// The static CometExec class. Used for getting the subquery result. pub comet_exec: CometExec<'a>, + /// The CometBatchIterator class. Used for iterating over the batches. + pub comet_batch_iterator: CometBatchIterator<'a>, } unsafe impl<'a> Send for JVMClasses<'a> {} @@ -256,6 +260,7 @@ impl JVMClasses<'_> { throwable_get_cause_method, comet_metric_node: CometMetricNode::new(env).unwrap(), comet_exec: CometExec::new(env).unwrap(), + comet_batch_iterator: CometBatchIterator::new(env).unwrap(), } }); } diff --git a/spark/src/main/java/org/apache/comet/CometBatchIterator.java b/spark/src/main/java/org/apache/comet/CometBatchIterator.java new file mode 100644 index 000000000..33603290c --- /dev/null +++ b/spark/src/main/java/org/apache/comet/CometBatchIterator.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet; + +import scala.collection.Iterator; + +import org.apache.spark.sql.vectorized.ColumnarBatch; + +import org.apache.comet.vector.NativeUtil; + +/** + * An iterator that can be used to get batches of Arrow arrays from a Spark iterator of + * ColumnarBatch. It will consume input iterator and return Arrow arrays by addresses. This is + * called by native code to retrieve Arrow arrays from Spark through JNI. + */ +public class CometBatchIterator { + final Iterator input; + final NativeUtil nativeUtil; + + CometBatchIterator(Iterator input, NativeUtil nativeUtil) { + this.input = input; + this.nativeUtil = nativeUtil; + } + + /** + * Get the next batches of Arrow arrays. It will consume input iterator and return Arrow arrays by + * addresses. If the input iterator is done, it will return a one negative element array + * indicating the end of the iterator. + */ + public long[] next() { + boolean hasBatch = input.hasNext(); + + if (!hasBatch) { + return new long[] {-1}; + } + + return nativeUtil.exportBatch(input.next()); + } +} diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 01405821f..20b2d384a 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -48,30 +48,24 @@ class CometExecIterator( extends Iterator[ColumnarBatch] { private val nativeLib = new Native() + private val nativeUtil = new NativeUtil + private val cometBatchIterators = inputs.map { iterator => + new CometBatchIterator(iterator, nativeUtil) + }.toArray private val plan = { val configs = createNativeConf - nativeLib.createPlan(id, configs, protobufQueryPlan, nativeMetrics) + nativeLib.createPlan(id, configs, cometBatchIterators, protobufQueryPlan, nativeMetrics) } - private val nativeUtil = new NativeUtil + private var nextBatch: Option[ColumnarBatch] = None private var currentBatch: ColumnarBatch = null private var closed: Boolean = false - private def peekNext(): ExecutionState = { - convertNativeResult(nativeLib.peekNext(plan)) - } + private def executeNative(): ExecutionState = { + val result = nativeLib.executePlan(plan) - private def executeNative( - input: Array[Array[Long]], - finishes: Array[Boolean], - numRows: Int): ExecutionState = { - convertNativeResult(nativeLib.executePlan(plan, input, finishes, numRows)) - } - - private def convertNativeResult(result: Array[Long]): ExecutionState = { val flag = result(0) if (flag == -1) EOF - else if (flag == 0) Pending else if (flag == 1) { val numRows = result(1) val addresses = result.slice(2, result.length) @@ -113,36 +107,12 @@ class CometExecIterator( /** The execution is finished - no more batch */ case object EOF extends ExecutionState - /** The execution is pending (e.g., blocking operator is still consuming batches) */ - case object Pending extends ExecutionState - - private def peek(): Option[ColumnarBatch] = { - peekNext() match { - case Batch(numRows, addresses) => - val cometVectors = nativeUtil.importVector(addresses) - Some(new 
ColumnarBatch(cometVectors.toArray, numRows)) - case _ => - None - } - } - - def getNextBatch( - inputArrays: Array[Array[Long]], - finishes: Array[Boolean], - numRows: Int): Option[ColumnarBatch] = { - executeNative(inputArrays, finishes, numRows) match { + def getNextBatch(): Option[ColumnarBatch] = { + executeNative() match { case EOF => None case Batch(numRows, addresses) => val cometVectors = nativeUtil.importVector(addresses) Some(new ColumnarBatch(cometVectors.toArray, numRows)) - case Pending => - if (finishes.forall(_ == true)) { - // Once no input, we should not get a pending flag. - throw new SparkException( - "Native execution should not be pending after reaching end of input batches") - } - // For pending, we keep reading next input. - None } } @@ -152,48 +122,12 @@ class CometExecIterator( if (nextBatch.isDefined) { return true } - // Before we pull next input batch, check if there is next output batch available - // from native side. Some operators might still have output batches ready produced - // from last input batch. For example, `expand` operator will produce output batches - // based on the input batch. - nextBatch = peek() - - // Next input batches are available, execute native query plan with the inputs until - // we get next output batch ready - while (nextBatch.isEmpty && inputs.exists(_.hasNext)) { - val batches = inputs.map { - case input if input.hasNext => Some(input.next()) - case _ => None - } - var numRows = -1 - val (batchAddresses, finishes) = batches - .map { - case Some(batch) => - numRows = batch.numRows() - (nativeUtil.exportBatch(batch), false) - case None => (Array.empty[Long], true) - } - .toArray - .unzip - - // At least one input batch should be consumed - assert(numRows != -1, "No input batch has been consumed") - - nextBatch = getNextBatch(batchAddresses, finishes, numRows) - } + nextBatch = getNextBatch() - // After we consume to the end of the iterators, the native side still can output batches - // back because there might be blocking operators e.g. Sort. We continue ask for batches - // until it returns empty columns. if (nextBatch.isEmpty) { - val finishes = inputs.map(_ => true).toArray - nextBatch = getNextBatch(inputs.map(_ => Array.empty[Long]).toArray, finishes, 0) - val hasNext = nextBatch.isDefined - if (!hasNext) { - close() - } - hasNext + close() + false } else { true } @@ -222,6 +156,7 @@ class CometExecIterator( currentBatch = null } nativeLib.releasePlan(plan) + // The allocator thoughts the exported ArrowArray and ArrowSchema structs are not released, // so it will report: // Caused by: java.lang.IllegalStateException: Memory was leaked by query. diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index 05bada522..8c1b8ac22 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -31,6 +31,9 @@ class Native extends NativeBase { * The id of the query plan. * @param configMap * The Java Map object for the configs of native engine. + * @param iterators + * the input iterators to the native query plan. It should be the same number as the number of + * scan nodes in the SparkPlan. * @param plan * the bytes of serialized SparkPlan. 
* @param metrics @@ -41,61 +44,21 @@ class Native extends NativeBase { @native def createPlan( id: Long, configMap: Map[String, String], + iterators: Array[CometBatchIterator], plan: Array[Byte], metrics: CometMetricNode): Long - /** - * Return the native query plan string for the given address of native query plan. For debugging - * purpose. - * - * @param plan - * the address to native query plan. - * @return - * the string of native query plan. - */ - @native def getPlanString(plan: Long): String - /** * Execute a native query plan based on given input Arrow arrays. * * @param plan * the address to native query plan. - * @param addresses - * the array of addresses of input Arrow arrays. The addresses are exported from Arrow Arrays - * so the number of addresses is always even number in the sequence like [array_address1, - * schema_address1, array_address2, schema_address2, ...]. Note that we can pass empty - * addresses to this API. In this case, it indicates there are no more input arrays to the - * native query plan, but the query plan possibly can still execute to produce output batch - * because it might contain blocking operators such as Sort, Aggregate. When this API returns - * an empty array back, it means the native query plan is finished. - * @param finishes - * whether the end of input arrays is reached for each input. If this is set to true, the - * native library will know there is no more inputs. But it doesn't mean the execution is - * finished immediately. For some blocking operators native execution will continue to output. - * @param numRows - * the number of rows in the batch. - * @return - * an array containing: 1) the status flag (0 for pending, 1 for normal returned arrays, - * -1 for end of output), 2) (optional) the number of rows if returned flag is 1 3) the - * addresses of output Arrow arrays - */ - @native def executePlan( - plan: Long, - addresses: Array[Array[Long]], - finishes: Array[Boolean], - numRows: Int): Array[Long] - - /** - * Peeks the next batch of output Arrow arrays from the native query plan without pulling any - * input batches. - * - * @param plan - * the address to native query plan. * @return - * an array containing: 1) the status flag (0 for pending, 1 for normal returned arrays, 2) - * (optional) the number of rows if returned flag is 1 3) the addresses of output Arrow arrays + * an array containing: 1) the status flag (1 for normal returned arrays, -1 for end of + * output) 2) (optional) the number of rows if returned flag is 1 3) the addresses of output + * Arrow arrays */ - @native def peekNext(plan: Long): Array[Long] + @native def executePlan(plan: Long): Array[Long] /** * Release and drop the native query plan object and context object. 
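With the `addresses`/`finishes` parameters gone, the JVM side only creates the plan with its `CometBatchIterator`s and then repeatedly asks the native side for output; input batches are pulled back over JNI on demand. The sketch below condenses what `CometExecIterator` in this patch already does into one hypothetical helper — `runNativePlan`, its parameters, and the import paths are illustrative scaffolding rather than part of the patch, and the config map type is assumed to match the `createPlan` declaration above.

import java.util.{Map => JMap}

import org.apache.spark.sql.comet.CometMetricNode
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.vector.NativeUtil

// Hypothetical driver, condensed from CometExecIterator in this patch.
def runNativePlan(
    id: Long,
    configs: JMap[String, String],
    inputs: Seq[Iterator[ColumnarBatch]],
    protobufQueryPlan: Array[Byte],
    nativeMetrics: CometMetricNode)(consume: ColumnarBatch => Unit): Unit = {
  val nativeLib = new Native()
  val nativeUtil = new NativeUtil
  // Wrap every Spark input iterator so the native ScanExec can pull batches over JNI.
  val batchIterators = inputs.map(it => new CometBatchIterator(it, nativeUtil)).toArray
  val plan = nativeLib.createPlan(id, configs, batchIterators, protobufQueryPlan, nativeMetrics)
  try {
    var done = false
    while (!done) {
      // executePlan no longer takes input addresses:
      //   result(0) == -1 -> end of output; result(0) == 1 -> result(1) rows at result(2..).
      val result = nativeLib.executePlan(plan)
      if (result(0) == -1) {
        done = true
      } else {
        val addresses = result.slice(2, result.length)
        consume(new ColumnarBatch(nativeUtil.importVector(addresses).toArray, result(1).toInt))
      }
    }
  } finally {
    nativeLib.releasePlan(plan)
  }
}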
diff --git a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala index 1cff74d39..f44752297 100644 --- a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala +++ b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala @@ -42,7 +42,6 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.comet.CometBatchScanExec import org.apache.spark.sql.comet.CometScanExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1139,7 +1138,7 @@ abstract class ParquetReadSuite extends CometTestBase { .where(s"a < ${Long.MaxValue}") .collect() } - assert(exception.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException]) + assert(exception.getMessage.contains("Column: [a], Expected: bigint, Found: INT32")) } }
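One detail worth keeping in mind when reading the patch as a whole: a single long[] layout is threaded through `NativeUtil.exportBatch` (producer), `CometBatchIterator.next` (JNI entry point), and the native `ScanExec::get_next` (consumer). A minimal sketch of that layout follows; the `describeExportedBatch` helper is illustrative only and not part of the patch.

// Layout used by this patch for a batch handed from the JVM to native code:
//   end of input:  Array(-1L)
//   otherwise:     Array(numRows, arrayAddr1, schemaAddr1, arrayAddr2, schemaAddr2, ...)
def describeExportedBatch(exported: Array[Long]): String = {
  if (exported(0) < 0) {
    "EOF" // ScanExec::get_next maps this to InputBatch::EOF
  } else {
    val numRows = exported(0)
    val addresses = exported.drop(1)
    // Addresses always come in (ArrowArray, ArrowSchema) pairs, mirroring the check in ScanExec.
    require(addresses.length % 2 == 0, "Arrow array/schema addresses must come in pairs")
    s"$numRows rows across ${addresses.length / 2} exported Arrow arrays"
  }
}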