Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Pull based native execution #69

Merged
merged 1 commit into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,14 @@ class NativeUtil {
* @param batch
* the input Comet columnar batch
* @return
* a list containing pairs of memory addresses in the format of (address of Arrow array,
* address of Arrow schema)
* a list containing number of rows + pairs of memory addresses in the format of (address of
* Arrow array, address of Arrow schema)
*/
def exportBatch(batch: ColumnarBatch): Array[Long] = {
val vectors = (0 until batch.numCols()).flatMap { index =>
val exportedVectors = mutable.ArrayBuffer.empty[Long]
exportedVectors += batch.numRows()

(0 until batch.numCols()).foreach { index =>
batch.column(index) match {
case a: CometVector =>
val valueVector = a.getValueVector
Expand All @@ -63,17 +66,16 @@ class NativeUtil {
arrowArray,
arrowSchema)

Seq((arrowArray, arrowSchema))
exportedVectors += arrowArray.memoryAddress()
exportedVectors += arrowSchema.memoryAddress()
case c =>
throw new SparkException(
"Comet execution only takes Arrow Arrays, but got " +
s"${c.getClass}")
}
}

vectors.flatMap { pair =>
Seq(pair._1.memoryAddress(), pair._2.memoryAddress())
}.toArray
exportedVectors.toArray
}

/**
Expand Down
9 changes: 9 additions & 0 deletions core/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,15 @@ impl From<CometError> for DataFusionError {
}
}

impl From<CometError> for ExecutionError {
fn from(value: CometError) -> Self {
match value {
CometError::Execution { source } => source,
_ => ExecutionError::GeneralError(value.to_string()),
}
}
}

impl jni::errors::ToException for CometError {
fn to_exception(&self) -> Exception {
match self {
Expand Down
83 changes: 57 additions & 26 deletions core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ use datafusion_physical_expr::{
AggregateExpr, ScalarFunctionExpr,
};
use itertools::Itertools;
use jni::objects::GlobalRef;
use num::{BigInt, ToPrimitive};

use crate::{
Expand All @@ -70,7 +71,7 @@ use crate::{
operators::expand::CometExpandExec,
shuffle_writer::ShuffleWriterExec,
},
operators::{CopyExec, ExecutionError, InputBatch, ScanExec},
operators::{CopyExec, ExecutionError, ScanExec},
serde::to_arrow_datatype,
spark_expression,
spark_expression::{
Expand All @@ -88,6 +89,8 @@ type PhyAggResult = Result<Vec<Arc<dyn AggregateExpr>>, ExecutionError>;
type PhyExprResult = Result<Vec<(Arc<dyn PhysicalExpr>, String)>, ExecutionError>;
type PartitionPhyExprResult = Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError>;

pub const TEST_EXEC_CONTEXT_ID: i64 = -1;

/// The query planner for converting Spark query plans to DataFusion query plans.
pub struct PhysicalPlanner {
// The execution context id of this planner.
Expand All @@ -105,7 +108,7 @@ impl PhysicalPlanner {
pub fn new() -> Self {
let execution_props = ExecutionProps::new();
Self {
exec_context_id: -1,
exec_context_id: TEST_EXEC_CONTEXT_ID,
execution_props,
}
}
Expand Down Expand Up @@ -612,24 +615,28 @@ impl PhysicalPlanner {

/// Create a DataFusion physical plan from Spark physical plan.
///
/// Note that we need `input_batches` parameter because we need to know the exact schema (not
/// only data type but also dictionary-encoding) at `ScanExec`s. It is because some DataFusion
/// operators, e.g., `ProjectionExec`, gets child operator schema during initialization and
/// uses it later for `RecordBatch`. We may be able to get rid of it once `RecordBatch`
/// relaxes schema check.
/// `inputs` is a vector of input source IDs. It is used to create `ScanExec`s. Each `ScanExec`
/// will be assigned a unique ID from `inputs` and the ID will be used to identify the input
/// source at JNI API.
///
/// Note that `ScanExec` will pull initial input batch during initialization. It is because we
/// need to know the exact schema (not only data type but also dictionary-encoding) at
/// `ScanExec`s. It is because some DataFusion operators, e.g., `ProjectionExec`, gets child
/// operator schema during initialization and uses it later for `RecordBatch`. We may be
/// able to get rid of it once `RecordBatch` relaxes schema check.
///
/// Note that we return created `Scan`s which will be kept at JNI API. JNI calls will use it to
/// feed in new input batch from Spark JVM side.
pub fn create_plan<'a>(
&'a self,
spark_plan: &'a Operator,
input_batches: &mut Vec<InputBatch>,
inputs: &mut Vec<Arc<GlobalRef>>,
) -> Result<(Vec<ScanExec>, Arc<dyn ExecutionPlan>), ExecutionError> {
let children = &spark_plan.children;
match spark_plan.op_struct.as_ref().unwrap() {
OpStruct::Projection(project) => {
assert!(children.len() == 1);
let (scans, child) = self.create_plan(&children[0], input_batches)?;
let (scans, child) = self.create_plan(&children[0], inputs)?;
let exprs: PhyExprResult = project
.project_list
.iter()
Expand All @@ -643,15 +650,15 @@ impl PhysicalPlanner {
}
OpStruct::Filter(filter) => {
assert!(children.len() == 1);
let (scans, child) = self.create_plan(&children[0], input_batches)?;
let (scans, child) = self.create_plan(&children[0], inputs)?;
let predicate =
self.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?;

Ok((scans, Arc::new(FilterExec::try_new(predicate, child)?)))
}
OpStruct::HashAgg(agg) => {
assert!(children.len() == 1);
let (scans, child) = self.create_plan(&children[0], input_batches)?;
let (scans, child) = self.create_plan(&children[0], inputs)?;

let group_exprs: PhyExprResult = agg
.grouping_exprs
Expand Down Expand Up @@ -716,13 +723,13 @@ impl PhysicalPlanner {
OpStruct::Limit(limit) => {
assert!(children.len() == 1);
let num = limit.limit;
let (scans, child) = self.create_plan(&children[0], input_batches)?;
let (scans, child) = self.create_plan(&children[0], inputs)?;

Ok((scans, Arc::new(LocalLimitExec::new(child, num as usize))))
}
OpStruct::Sort(sort) => {
assert!(children.len() == 1);
let (scans, child) = self.create_plan(&children[0], input_batches)?;
let (scans, child) = self.create_plan(&children[0], inputs)?;

let exprs: Result<Vec<PhysicalSortExpr>, ExecutionError> = sort
.sort_orders
Expand All @@ -741,21 +748,32 @@ impl PhysicalPlanner {
}
OpStruct::Scan(scan) => {
let fields = scan.fields.iter().map(to_arrow_datatype).collect_vec();
if input_batches.is_empty() {

// If it is not test execution context for unit test, we should have at least one
// input source
if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() {
return Err(ExecutionError::GeneralError(
"No input batch for scan".to_string(),
"No input for scan".to_string(),
));
}
// Consumes the first input batch source for the scan
let input_batch = input_batches.remove(0);

// Consumes the first input source for the scan
let input_source = if self.exec_context_id == TEST_EXEC_CONTEXT_ID
&& inputs.is_empty()
{
// For unit test, we will set input batch to scan directly by `set_input_batch`.
None
} else {
Some(inputs.remove(0))
};

// The `ScanExec` operator will take actual arrays from Spark during execution
let scan = ScanExec::new(input_batch, fields);
let scan = ScanExec::new(self.exec_context_id, input_source, fields)?;
Ok((vec![scan.clone()], Arc::new(scan)))
}
OpStruct::ShuffleWriter(writer) => {
assert!(children.len() == 1);
let (scans, child) = self.create_plan(&children[0], input_batches)?;
let (scans, child) = self.create_plan(&children[0], inputs)?;

let partitioning = self
.create_partitioning(writer.partitioning.as_ref().unwrap(), child.schema())?;
Expand All @@ -772,7 +790,7 @@ impl PhysicalPlanner {
}
OpStruct::Expand(expand) => {
assert!(children.len() == 1);
let (scans, child) = self.create_plan(&children[0], input_batches)?;
let (scans, child) = self.create_plan(&children[0], inputs)?;

let mut projections = vec![];
let mut projection = vec![];
Expand Down Expand Up @@ -805,6 +823,18 @@ impl PhysicalPlanner {
.collect();
let schema = Arc::new(Schema::new(fields));

// `Expand` operator keeps the input batch and expands it to multiple output
// batches. However, `ScanExec` will reuse input arrays for the next
// input batch. Therefore, we need to copy the input batch to avoid
// the data corruption. Note that we only need to copy the input batch
// if the child operator is `ScanExec`, because other operators after `ScanExec`
// will create new arrays for the output batch.
let child = if child.as_any().downcast_ref::<ScanExec>().is_some() {
Arc::new(CopyExec::new(child))
} else {
child
};

Ok((
scans,
Arc::new(CometExpandExec::new(projections, child, schema)),
Expand Down Expand Up @@ -997,9 +1027,9 @@ mod tests {
let values = Int32Array::from(vec![0, 1, 2, 3]);
let input_array = DictionaryArray::new(keys, Arc::new(values));
let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count);
let mut input_batches = vec![input_batch];

let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut input_batches).unwrap();
let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![]).unwrap();
scans[0].set_input_batch(input_batch);

let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
Expand Down Expand Up @@ -1077,9 +1107,11 @@ mod tests {
let values = StringArray::from(vec!["foo", "bar", "hello", "comet"]);
let input_array = DictionaryArray::new(keys, Arc::new(values));
let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count);
let mut input_batches = vec![input_batch];

let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut input_batches).unwrap();
let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![]).unwrap();

// Scan's schema is determined by the input batch, so we need to set it before execution.
scans[0].set_input_batch(input_batch);

let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
Expand Down Expand Up @@ -1147,8 +1179,7 @@ mod tests {
let op = create_filter(op_scan, 0);
let planner = PhysicalPlanner::new();

let mut input_batches = vec![InputBatch::EOF];
let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut input_batches).unwrap();
let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![]).unwrap();

let scan = &mut scans[0];
scan.set_input_batch(InputBatch::EOF);
Expand Down
Loading