Commit

Merge remote-tracking branch 'apache/comet-parquet-exec' into schema-adapter-fixes
andygrove committed Dec 5, 2024
2 parents b6036f2 + e0d8077 commit 0b43b23
Showing 3 changed files with 77 additions and 64 deletions.
88 changes: 37 additions & 51 deletions native/core/src/execution/datafusion/planner.rs
@@ -121,7 +121,6 @@ use datafusion_physical_expr::LexOrdering;
use itertools::Itertools;
use jni::objects::GlobalRef;
use num::{BigInt, ToPrimitive};
-use parquet::schema::parser::parse_message_type;
use std::cmp::max;
use std::{collections::HashMap, sync::Arc};
use url::Url;
@@ -950,50 +949,28 @@ impl PhysicalPlanner {
))
}
OpStruct::NativeScan(scan) => {
-let data_schema = parse_message_type(&scan.data_schema).unwrap();
-let required_schema = parse_message_type(&scan.required_schema).unwrap();
-
-let data_schema_descriptor =
-parquet::schema::types::SchemaDescriptor::new(Arc::new(data_schema));
-let data_schema_arrow = Arc::new(
-parquet::arrow::schema::parquet_to_arrow_schema(&data_schema_descriptor, None)
-.unwrap(),
-);
-
-let required_schema_descriptor =
-parquet::schema::types::SchemaDescriptor::new(Arc::new(required_schema));
-let required_schema_arrow = Arc::new(
-parquet::arrow::schema::parquet_to_arrow_schema(
-&required_schema_descriptor,
-None,
-)
-.unwrap(),
-);
-
-let partition_schema_arrow = scan
-.partition_schema
+let data_schema = convert_spark_types_to_arrow_schema(scan.data_schema.as_slice());
+let required_schema: SchemaRef =
+convert_spark_types_to_arrow_schema(scan.required_schema.as_slice());
+let partition_schema: SchemaRef =
+convert_spark_types_to_arrow_schema(scan.partition_schema.as_slice());
+let projection_vector: Vec<usize> = scan
+.projection_vector
.iter()
-.map(to_arrow_datatype)
-.collect_vec();
-let partition_fields: Vec<_> = partition_schema_arrow
-.iter()
-.enumerate()
-.map(|(idx, data_type)| {
-Field::new(format!("part_{}", idx), data_type.clone(), true)
-})
+.map(|offset| *offset as usize)
.collect();

// Convert the Spark expressions to Physical expressions
let data_filters: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = scan
.data_filters
.iter()
-.map(|expr| self.create_expr(expr, Arc::clone(&required_schema_arrow)))
+.map(|expr| self.create_expr(expr, Arc::clone(&required_schema)))
.collect();

// Create a conjunctive form of the vector because ParquetExecBuilder takes
// a single expression
let data_filters = data_filters?;
-let test_data_filters = data_filters.clone().into_iter().reduce(|left, right| {
+let cnf_data_filters = data_filters.clone().into_iter().reduce(|left, right| {
Arc::new(BinaryExpr::new(
left,
datafusion::logical_expr::Operator::And,
@@ -1064,29 +1041,21 @@ impl PhysicalPlanner {
assert_eq!(file_groups.len(), partition_count);

let object_store_url = ObjectStoreUrl::local_filesystem();
+let partition_fields: Vec<Field> = partition_schema
+.fields()
+.iter()
+.map(|field| {
+Field::new(field.name(), field.data_type().clone(), field.is_nullable())
+})
+.collect_vec();
let mut file_scan_config =
-FileScanConfig::new(object_store_url, Arc::clone(&data_schema_arrow))
+FileScanConfig::new(object_store_url, Arc::clone(&data_schema))
.with_file_groups(file_groups)
.with_table_partition_cols(partition_fields);

-// Check for projection, if so generate the vector and add to FileScanConfig.
-let mut projection_vector: Vec<usize> =
-Vec::with_capacity(required_schema_arrow.fields.len());
-// TODO: could be faster with a hashmap rather than iterating over data_schema_arrow with index_of.
-required_schema_arrow.fields.iter().for_each(|field| {
-projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap());
-});
-
-partition_schema_arrow
-.iter()
-.enumerate()
-.for_each(|(idx, _)| {
-projection_vector.push(idx + data_schema_arrow.fields.len());
-});
-
assert_eq!(
projection_vector.len(),
-required_schema_arrow.fields.len() + partition_schema_arrow.len()
+required_schema.fields.len() + partition_schema.fields.len()
);
file_scan_config = file_scan_config.with_projection(Some(projection_vector));

@@ -1099,7 +1068,7 @@
.with_table_parquet_options(table_parquet_options)
.with_schema_adapter_factory(Arc::new(CometSchemaAdapterFactory::default()));

-if let Some(filter) = test_data_filters {
+if let Some(filter) = cnf_data_filters {
builder = builder.with_predicate(filter);
}

@@ -2307,6 +2276,23 @@ fn from_protobuf_eval_mode(value: i32) -> Result<EvalMode, prost::DecodeError> {
}
}

+fn convert_spark_types_to_arrow_schema(
+spark_types: &[spark_operator::SparkStructField],
+) -> SchemaRef {
+let arrow_fields = spark_types
+.iter()
+.map(|spark_type| {
+Field::new(
+String::clone(&spark_type.name),
+to_arrow_datatype(spark_type.data_type.as_ref().unwrap()),
+spark_type.nullable,
+)
+})
+.collect_vec();
+let arrow_schema: SchemaRef = Arc::new(Schema::new(arrow_fields));
+arrow_schema
+}
+
#[cfg(test)]
mod tests {
use std::{sync::Arc, task::Poll};
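
Note: the `cnf_data_filters` reduction in the hunk above folds the per-column data filters into one conjunctive predicate, because ParquetExecBuilder accepts a single predicate expression. Below is a minimal standalone sketch of that fold, not code from this commit: the `conjunction` helper name is made up, and the import paths are assumptions that may differ across DataFusion versions.

use std::sync::Arc;

use datafusion::logical_expr::Operator;
use datafusion::physical_expr::expressions::BinaryExpr;
use datafusion::physical_expr::PhysicalExpr;

/// Fold a list of physical predicates into a single `a AND b AND c ...` expression.
/// Returns None when the list is empty (nothing to push down).
fn conjunction(filters: Vec<Arc<dyn PhysicalExpr>>) -> Option<Arc<dyn PhysicalExpr>> {
    filters.into_iter().reduce(|left, right| {
        // Same shape as the reduce in the planner hunk above: left AND right.
        Arc::new(BinaryExpr::new(left, Operator::And, right))
    })
}
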
13 changes: 10 additions & 3 deletions native/proto/src/proto/operator.proto
@@ -61,6 +61,12 @@ message SparkFilePartition {
repeated SparkPartitionedFile partitioned_file = 1;
}

+message SparkStructField {
+string name = 1;
+spark.spark_expression.DataType data_type = 2;
+bool nullable = 3;
+}
+
message Scan {
repeated spark.spark_expression.DataType fields = 1;
// The source of the scan (e.g. file scan, broadcast exchange, shuffle, etc). This
@@ -75,11 +81,12 @@ message NativeScan {
// is purely for informational purposes when viewing native query plans in
// debug mode.
string source = 2;
-string required_schema = 3;
-string data_schema = 4;
-repeated spark.spark_expression.DataType partition_schema = 5;
+repeated SparkStructField required_schema = 3;
+repeated SparkStructField data_schema = 4;
+repeated SparkStructField partition_schema = 5;
repeated spark.spark_expression.Expr data_filters = 6;
repeated SparkFilePartition file_partitions = 7;
+repeated int64 projection_vector = 8;
}

message Projection {
40 changes: 30 additions & 10 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -36,7 +36,6 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD}
-import org.apache.spark.sql.execution.datasources.parquet.SparkToParquetSchemaConverter
import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
@@ -2520,18 +2519,28 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
case _ =>
}

-val requiredSchemaParquet =
-new SparkToParquetSchemaConverter(conf).convert(scan.requiredSchema)
-val dataSchemaParquet =
-new SparkToParquetSchemaConverter(conf).convert(scan.relation.dataSchema)
-val partitionSchema = scan.relation.partitionSchema.fields.flatMap { field =>
-serializeDataType(field.dataType)
-}
+val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields)
+val requiredSchema = schema2Proto(scan.requiredSchema.fields)
+val dataSchema = schema2Proto(scan.relation.dataSchema.fields)
+
+val data_schema_idxs = scan.requiredSchema.fields.map(field => {
+scan.relation.dataSchema.fieldIndex(field.name)
+})
+val partition_schema_idxs = Array
+.range(
+scan.relation.dataSchema.fields.length,
+scan.relation.dataSchema.length + scan.relation.partitionSchema.fields.length)
+
+val projection_vector = (data_schema_idxs ++ partition_schema_idxs).map(idx =>
+idx.toLong.asInstanceOf[java.lang.Long])
+
+nativeScanBuilder.addAllProjectionVector(projection_vector.toIterable.asJava)
+
// In `CometScanRule`, we ensure partitionSchema is supported.
assert(partitionSchema.length == scan.relation.partitionSchema.fields.length)

-nativeScanBuilder.setRequiredSchema(requiredSchemaParquet.toString)
-nativeScanBuilder.setDataSchema(dataSchemaParquet.toString)
+nativeScanBuilder.addAllDataSchema(dataSchema.toIterable.asJava)
+nativeScanBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava)
nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava)

Some(result.setNativeScan(nativeScanBuilder).build())
@@ -3198,6 +3207,17 @@
true
}

+private def schema2Proto(
+fields: Array[StructField]): Array[OperatorOuterClass.SparkStructField] = {
+val fieldBuilder = OperatorOuterClass.SparkStructField.newBuilder()
+fields.map(field => {
+fieldBuilder.setName(field.name)
+fieldBuilder.setDataType(serializeDataType(field.dataType).get)
+fieldBuilder.setNullable(field.nullable)
+fieldBuilder.build()
+})
+}
+
private def partition2Proto(
partition: FilePartition,
nativeScanBuilder: OperatorOuterClass.NativeScan.Builder,
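
Note: the projection vector serialized in QueryPlanSerde above is a list of column ordinals into the file's data schema, with partition columns treated as extra columns appended after the data columns; the native planner now consumes it directly via FileScanConfig::with_projection instead of recomputing it with index_of. The standalone sketch below illustrates that indexing convention only; it uses the arrow crate directly, the field names are made up, and it is not code from this commit.

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};

fn main() {
    // File (data) schema as written to Parquet: a, b, c.
    let data_schema = Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, true),
        Field::new("c", DataType::Int64, true),
    ]);

    // One partition column, logically appended after the data columns.
    let partition_field = Field::new("part_date", DataType::Utf8, true);
    let full_schema = Schema::new(
        data_schema
            .fields()
            .iter()
            .cloned()
            .chain(std::iter::once(Arc::new(partition_field)))
            .collect::<Vec<_>>(),
    );

    // Required columns c and a by their index in the data schema, then the
    // partition column at offset data_schema.fields().len() (= 3).
    let projection_vector: Vec<usize> = vec![2, 0, data_schema.fields().len()];

    let projected = full_schema.project(&projection_vector).unwrap();
    let names: Vec<String> = projected
        .fields()
        .iter()
        .map(|f| f.name().to_string())
        .collect();
    assert_eq!(names, vec!["c", "a", "part_date"]);
}
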
