feat: Support sort merge join
viirya committed Mar 9, 2024
1 parent 488c523 commit 0b7f600
Showing 8 changed files with 296 additions and 14 deletions.
@@ -113,7 +113,11 @@ public UTF8String getUTF8String(int rowId) {
byte[] result = new byte[length];
Platform.copyMemory(
null, valueBufferAddress + offset, result, Platform.BYTE_ARRAY_OFFSET, length);
-      return UTF8String.fromString(convertToUuid(result).toString());
+      if (length == 16) {
+        return UTF8String.fromString(convertToUuid(result).toString());
+      } else {
+        return UTF8String.fromBytes(result);
+      }
}
}

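
Note on the hunk above: Parquet's UUID logical type is a 16-byte fixed-length array, so only 16-byte values should take the UUID-to-string path; any other length is now returned as plain UTF-8 bytes rather than being force-converted. A minimal Scala sketch of what the 16-byte branch amounts to, assuming `convertToUuid` rebuilds the UUID from big-endian high/low words (the layout `java.util.UUID` expects):

```scala
import java.nio.ByteBuffer
import java.util.UUID

// Hypothetical stand-in for the Java helper `convertToUuid`: reassemble a
// UUID from a 16-byte buffer (most-significant 8 bytes, then least).
def toUuid(bytes: Array[Byte]): UUID = {
  require(bytes.length == 16, "UUIDs are 16-byte fixed-length arrays")
  val bb = ByteBuffer.wrap(bytes)
  new UUID(bb.getLong, bb.getLong)
}
```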
93 changes: 91 additions & 2 deletions core/src/execution/datafusion/planner.rs
@@ -32,13 +32,14 @@ use datafusion::{
physical_plan::{
aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy},
filter::FilterExec,
+        joins::SortMergeJoinExec,
limit::LocalLimitExec,
projection::ProjectionExec,
sorts::sort::SortExec,
ExecutionPlan, Partitioning,
},
};
-use datafusion_common::ScalarValue;
+use datafusion_common::{JoinType as DFJoinType, ScalarValue};
use datafusion_physical_expr::{
execution_props::ExecutionProps,
expressions::{
@@ -79,7 +80,7 @@ use crate::{
agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr, Expr,
ScalarFunc,
},
-        spark_operator::{operator::OpStruct, Operator},
+        spark_operator::{operator::OpStruct, JoinType, Operator},
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
},
};
@@ -849,6 +850,85 @@ impl PhysicalPlanner {
Arc::new(CometExpandExec::new(projections, child, schema)),
))
}
OpStruct::SortMergeJoin(join) => {
assert!(children.len() == 2);
let (mut left_scans, left) = self.create_plan(&children[0], inputs)?;
let (mut right_scans, right) = self.create_plan(&children[1], inputs)?;

left_scans.append(&mut right_scans);

let left_join_exprs = join
.left_join_keys
.iter()
.map(|expr| self.create_expr(expr, left.schema()))
.collect::<Result<Vec<_>, _>>()?;
let right_join_exprs = join
.right_join_keys
.iter()
.map(|expr| self.create_expr(expr, right.schema()))
.collect::<Result<Vec<_>, _>>()?;

let join_on = left_join_exprs
.into_iter()
.zip(right_join_exprs)
.collect::<Vec<_>>();

let join_type = match join.join_type.try_into() {
Ok(JoinType::Inner) => DFJoinType::Inner,
Ok(JoinType::LeftOuter) => DFJoinType::Left,
Ok(JoinType::RightOuter) => DFJoinType::Right,
Ok(JoinType::FullOuter) => DFJoinType::Full,
Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi,
Ok(JoinType::RightSemi) => DFJoinType::RightSemi,
Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti,
Ok(JoinType::RightAnti) => DFJoinType::RightAnti,
Err(_) => {
return Err(ExecutionError::GeneralError(format!(
"Unsupported join type: {:?}",
join.join_type
)));
}
};

let sort_options = join
.sort_options
.iter()
.map(|sort_option| {
let sort_expr = self.create_sort_expr(sort_option, left.schema()).unwrap();
SortOptions {
descending: sort_expr.options.descending,
nulls_first: sort_expr.options.nulls_first,
}
})
.collect();

                // DataFusion's `SortMergeJoinExec` buffers input batches internally, so we
                // need to copy the input batches; otherwise upstream operators that reuse
                // their input buffers would corrupt the buffered data.
let left = if op_reuse_array(&left) {
Arc::new(CopyExec::new(left))
} else {
left
};

let right = if op_reuse_array(&right) {
Arc::new(CopyExec::new(right))
} else {
right
};

let join = Arc::new(SortMergeJoinExec::try_new(
left,
right,
join_on,
None,
join_type,
sort_options,
false,
)?);

Ok((left_scans, join))
}
}
}
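
The `CopyExec` wrapping above guards against a subtle hazard: pass-through operators hand out arrays backed by buffers they overwrite on the next batch, while `SortMergeJoinExec` buffers batches across polls. A toy Scala illustration of the failure mode (not Comet code; a plain array stands in for an Arrow buffer):

```scala
// A producer that reuses one buffer across calls, like a scan reusing its
// native buffers between batches.
val buffer = new Array[Int](4)
def nextBatch(v: Int): Array[Int] = { java.util.Arrays.fill(buffer, v); buffer }

// A consumer that retains batches (as a sort merge join must) sees the first
// batch silently overwritten by the second...
val retained = Seq(1, 2).map(nextBatch)
assert(retained.head.forall(_ == 2))

// ...unless each batch is deep-copied on arrival, which is what wrapping the
// join inputs in CopyExec achieves.
val copied = Seq(1, 2).map(v => nextBatch(v).clone())
assert(copied.head.forall(_ == 1))
```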

@@ -1017,6 +1097,15 @@ impl From<ExpressionError> for DataFusionError {
}
}

/// Returns true if the given operator is likely to return its input arrays as its
/// output arrays without modification.
fn op_reuse_array(op: &Arc<dyn ExecutionPlan>) -> bool {
op.as_any().downcast_ref::<ScanExec>().is_some()
|| op.as_any().downcast_ref::<LocalLimitExec>().is_some()
|| op.as_any().downcast_ref::<ProjectionExec>().is_some()
|| op.as_any().downcast_ref::<FilterExec>().is_some()
}

#[cfg(test)]
mod tests {
use std::{sync::Arc, task::Poll};
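
`op_reuse_array` is a conservative allow-list: the four downcasts name the operators known to forward input arrays untouched, and anything unrecognized is assumed not to reuse its buffers, so no copy is inserted. A hypothetical Spark-side analogue of the same predicate, for illustration only (the real check is the Rust function above):

```scala
import org.apache.spark.sql.execution.{FilterExec, LocalLimitExec, ProjectExec, SparkPlan}

// Would this operator pass its input batches through unchanged, so that a
// buffering consumer has to copy them first?
def reusesInputArrays(plan: SparkPlan): Boolean = plan match {
  case _: ProjectExec | _: FilterExec | _: LocalLimitExec => true
  case _ => false
}
```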
2 changes: 1 addition & 1 deletion core/src/execution/operators/copy.rs
@@ -91,7 +91,7 @@ impl ExecutionPlan for CopyExec {
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
-        self.input.children()
+        vec![self.input.clone()]
}

fn with_new_children(
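
The one-line `copy.rs` fix restores a basic plan-tree invariant: `children()` must return a node's direct inputs, not its grandchildren. Returning `self.input.children()` skipped the input itself, so any rewrite pairing `children()` with `with_new_children()` would silently drop `CopyExec`'s input. A toy Scala sketch of the invariant, using hypothetical types rather than DataFusion's API:

```scala
sealed trait Plan { def children: Seq[Plan] }
final case class Leaf(name: String) extends Plan { val children: Seq[Plan] = Nil }
final case class Copy(input: Plan) extends Plan {
  // Correct: the direct input is the child. The old behavior, input.children,
  // would have reported Leaf's (empty) children and made Copy look childless.
  val children: Seq[Plan] = Seq(input)
}

assert(Copy(Leaf("scan")).children == Seq(Leaf("scan")))
```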
19 changes: 19 additions & 0 deletions core/src/execution/proto/operator.proto
@@ -40,6 +40,7 @@ message Operator {
Limit limit = 105;
ShuffleWriter shuffle_writer = 106;
Expand expand = 107;
+    SortMergeJoin sort_merge_join = 108;
}
}

@@ -87,3 +88,21 @@ message Expand {
repeated spark.spark_expression.Expr project_list = 1;
int32 num_expr_per_project = 3;
}

message SortMergeJoin {
repeated spark.spark_expression.Expr left_join_keys = 1;
repeated spark.spark_expression.Expr right_join_keys = 2;
JoinType join_type = 3;
repeated spark.spark_expression.Expr sort_options = 4;
}

enum JoinType {
Inner = 0;
LeftOuter = 1;
RightOuter = 2;
FullOuter = 3;
LeftSemi = 4;
RightSemi = 5;
LeftAnti = 6;
RightAnti = 7;
}
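
For reference, the generated builders make the new message easy to assemble from the JVM side. A minimal sketch, with the join keys and sort options omitted and the `OperatorOuterClass` names taken from the serde imports later in this commit:

```scala
import org.apache.comet.serde.OperatorOuterClass.{JoinType, SortMergeJoin}

val smj: SortMergeJoin = SortMergeJoin
  .newBuilder()
  .setJoinType(JoinType.Inner)
  .build()
```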
@@ -26,7 +26,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.SparkSessionExtensions
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle}
@@ -38,6 +37,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -222,12 +222,16 @@
*/
// spotless:on
private def transform(plan: SparkPlan): SparkPlan = {
-    def transform1(op: UnaryExecNode): Option[Operator] = {
-      op.child match {
-        case childNativeOp: CometNativeExec =>
-          QueryPlanSerde.operator2Proto(op, childNativeOp.nativeOp)
-        case _ =>
-          None
+    def transform1(op: SparkPlan): Option[Operator] = {
+      val allNativeExec = op.children.map {
+        case childNativeOp: CometNativeExec => Some(childNativeOp.nativeOp)
+        case _ => None
+      }
+
+      if (allNativeExec.forall(_.isDefined)) {
+        QueryPlanSerde.operator2Proto(op, allNativeExec.map(_.get): _*)
+      } else {
+        None
}
}

@@ -333,6 +337,26 @@
op
}

case op: SortMergeJoinExec
if isCometOperatorEnabled(conf, "sort_merge_join") &&
op.children.forall(isCometNative(_)) =>
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometSortMergeJoinExec(
nativeOp,
op,
op.leftKeys,
op.rightKeys,
op.joinType,
op.condition,
op.left,
op.right,
SerializedPlan(None))
case None =>
op
}

case c @ CoalesceExec(numPartitions, child)
if isCometOperatorEnabled(conf, "coalesce")
&& isCometNative(child) =>
@@ -547,7 +571,9 @@

private[comet] def isCometOperatorEnabled(conf: SQLConf, operator: String): Boolean = {
val operatorFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.enabled"
-    conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf)
+    val operatorDisabledFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.disabled"
+    conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf) &&
+      !conf.getConfString(operatorDisabledFlag, "false").toBoolean
}

private[comet] def isCometBroadCastEnabled(conf: SQLConf): Boolean = {
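
Note the precedence in the new expression: since `&&` binds tighter than `||` in Scala, the `disabled` flag only vetoes the `all.enabled` path; an operator whose own flag is set explicitly is unaffected. A sketch of the intended usage, assuming `COMET_EXEC_CONFIG_PREFIX` resolves to `spark.comet.exec` and a live SparkSession `spark`:

```scala
// Turn every Comet operator on wholesale...
spark.conf.set("spark.comet.exec.all.enabled", "true")
// ...then opt one operator back out:
spark.conf.set("spark.comet.exec.sort_merge_join.disabled", "true")

// Setting the operator's own flag bypasses the disabled check entirely:
spark.conf.set("spark.comet.exec.sort_merge_join.enabled", "true")
```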
60 changes: 59 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -26,21 +26,23 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Count, Final, First, Last, Max, Min, Partial, Sum}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.comet.{CometHashAggregateExec, CometPlan, CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus}
import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo}
-import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator}
+import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator}
import org.apache.comet.shims.ShimQueryPlanSerde

/**
@@ -1836,6 +1838,62 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
}
}

case join: SortMergeJoinExec if isCometOperatorEnabled(op.conf, "sort_merge_join") =>
// `requiredOrders` and `getKeyOrdering` are copied from Spark's SortMergeJoinExec.
def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
keys.map(SortOrder(_, Ascending))
}

def getKeyOrdering(
keys: Seq[Expression],
childOutputOrdering: Seq[SortOrder]): Seq[SortOrder] = {
val requiredOrdering = requiredOrders(keys)
if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) {
keys.zip(childOutputOrdering).map { case (key, childOrder) =>
val sameOrderExpressionsSet = ExpressionSet(childOrder.children) - key
SortOrder(key, Ascending, sameOrderExpressionsSet.toSeq)
}
} else {
requiredOrdering
}
}

// TODO: Support SortMergeJoin with join condition after new DataFusion release
if (join.condition.isDefined) {
return None
}

val joinType = join.joinType match {
case Inner => JoinType.Inner
case LeftOuter => JoinType.LeftOuter
case RightOuter => JoinType.RightOuter
case FullOuter => JoinType.FullOuter
case LeftSemi => JoinType.LeftSemi
case LeftAnti => JoinType.LeftAnti
case _ => return None // Spark doesn't support other join types
}

val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))

val sortOptions = getKeyOrdering(join.leftKeys, join.left.outputOrdering)
.map(exprToProto(_, join.left.output))

if (sortOptions.forall(_.isDefined) &&
leftKeys.forall(_.isDefined) &&
rightKeys.forall(_.isDefined) &&
childOp.nonEmpty) {
val joinBuilder = OperatorOuterClass.SortMergeJoin
.newBuilder()
.setJoinType(joinType)
.addAllSortOptions(sortOptions.map(_.get).asJava)
.addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
.addAllRightJoinKeys(rightKeys.map(_.get).asJava)
Some(result.setSortMergeJoin(joinBuilder).build())
} else {
None
}

case op if isCometSink(op) =>
// These operators are source of Comet native execution chain
val scanBuilder = OperatorOuterClass.Scan.newBuilder()
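
Taken together, the serde only fires for equi-joins without an extra condition (per the TODO, condition support waits on a DataFusion release) and for the join types both engines map. A hedged end-to-end sketch of exercising the new path; the config names assume the `spark.comet.exec` prefix used above, and broadcast is disabled so Spark actually plans a `SortMergeJoinExec`:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("comet-smj").getOrCreate()
spark.conf.set("spark.comet.exec.enabled", "true")
spark.conf.set("spark.comet.exec.sort_merge_join.enabled", "true")
// Keep Spark from choosing a broadcast join.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

val left = spark.range(1000).toDF("id")
val right = spark.range(1000).toDF("id")
// Plain equi-join, no residual condition: eligible for native sort merge join.
left.join(right, "id").count()
```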
[Diff for the remaining two changed files did not load.]