From e0d807736b97ac1351dc34336c249bc7ac540558 Mon Sep 17 00:00:00 2001
From: Matt Butrovich
Date: Thu, 5 Dec 2024 12:10:56 -0500
Subject: [PATCH] [comet-parquet-exec] Simplify schema logic for CometNativeScan (#1142)

* Serialize original data schema and required schema, generate projection vector on the Java side.
* Sending over more schema info like column names and nullability.
* Using the new stuff in the proto. About to take the old out.
* Remove old logic.
* Remove errant print.
* Remove commented print. Format.
* Fix projection_vector to include partition_schema cols correctly.
* Rename variable.
---
 .../core/src/execution/datafusion/planner.rs  | 96 ++++++++-----------
 .../execution/datafusion/schema_adapter.rs    |  3 +-
 native/proto/src/proto/operator.proto         | 13 ++-
 .../apache/comet/serde/QueryPlanSerde.scala   | 40 ++++++--
 4 files changed, 82 insertions(+), 70 deletions(-)

diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index c5147d772..40fe4515c 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -121,7 +121,6 @@ use datafusion_physical_expr::LexOrdering;
 use itertools::Itertools;
 use jni::objects::GlobalRef;
 use num::{BigInt, ToPrimitive};
-use parquet::schema::parser::parse_message_type;
 use std::cmp::max;
 use std::{collections::HashMap, sync::Arc};
 use url::Url;
@@ -950,50 +949,28 @@ impl PhysicalPlanner {
                 ))
             }
             OpStruct::NativeScan(scan) => {
-                let data_schema = parse_message_type(&scan.data_schema).unwrap();
-                let required_schema = parse_message_type(&scan.required_schema).unwrap();
-
-                let data_schema_descriptor =
-                    parquet::schema::types::SchemaDescriptor::new(Arc::new(data_schema));
-                let data_schema_arrow = Arc::new(
-                    parquet::arrow::schema::parquet_to_arrow_schema(&data_schema_descriptor, None)
-                        .unwrap(),
-                );
-
-                let required_schema_descriptor =
-                    parquet::schema::types::SchemaDescriptor::new(Arc::new(required_schema));
-                let required_schema_arrow = Arc::new(
-                    parquet::arrow::schema::parquet_to_arrow_schema(
-                        &required_schema_descriptor,
-                        None,
-                    )
-                    .unwrap(),
-                );
-
-                let partition_schema_arrow = scan
-                    .partition_schema
+                let data_schema = convert_spark_types_to_arrow_schema(scan.data_schema.as_slice());
+                let required_schema: SchemaRef =
+                    convert_spark_types_to_arrow_schema(scan.required_schema.as_slice());
+                let partition_schema: SchemaRef =
+                    convert_spark_types_to_arrow_schema(scan.partition_schema.as_slice());
+                let projection_vector: Vec<usize> = scan
+                    .projection_vector
                     .iter()
-                    .map(to_arrow_datatype)
-                    .collect_vec();
-                let partition_fields: Vec<_> = partition_schema_arrow
-                    .iter()
-                    .enumerate()
-                    .map(|(idx, data_type)| {
-                        Field::new(format!("part_{}", idx), data_type.clone(), true)
-                    })
+                    .map(|offset| *offset as usize)
                     .collect();

                 // Convert the Spark expressions to Physical expressions
                 let data_filters: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = scan
                     .data_filters
                     .iter()
-                    .map(|expr| self.create_expr(expr, Arc::clone(&required_schema_arrow)))
+                    .map(|expr| self.create_expr(expr, Arc::clone(&required_schema)))
                     .collect();

                 // Create a conjunctive form of the vector because ParquetExecBuilder takes
                 // a single expression
                 let data_filters = data_filters?;
-                let test_data_filters = data_filters.clone().into_iter().reduce(|left, right| {
+                let cnf_data_filters = data_filters.clone().into_iter().reduce(|left, right| {
                     Arc::new(BinaryExpr::new(
                         left,
                         datafusion::logical_expr::Operator::And,
                         right,
                     ))
                 });
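The reduce above is what folds the list of pushed-down filters into the single conjunctive predicate that ParquetExecBuilder expects. A minimal standalone sketch of the same fold, using a toy expression type rather than DataFusion's Arc<dyn PhysicalExpr>:

    // Toy stand-in for a physical filter expression; not DataFusion's type.
    #[derive(Debug)]
    enum Expr {
        Col(&'static str),
        And(Box<Expr>, Box<Expr>),
    }

    fn main() {
        let filters = vec![Expr::Col("a"), Expr::Col("b"), Expr::Col("c")];
        // Iterator::reduce returns None for an empty vector, which mirrors why
        // the builder only receives a predicate when at least one filter exists.
        let cnf = filters
            .into_iter()
            .reduce(|left, right| Expr::And(Box::new(left), Box::new(right)));
        // Prints: Some(And(And(Col("a"), Col("b")), Col("c")))
        println!("{:?}", cnf);
    }

The left fold nests ANDs to the left, so filter order is preserved in the resulting predicate tree.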
@@ -1064,29 +1041,21 @@ impl PhysicalPlanner {
                 assert_eq!(file_groups.len(), partition_count);

                 let object_store_url = ObjectStoreUrl::local_filesystem();
+                let partition_fields: Vec<Field> = partition_schema
+                    .fields()
+                    .iter()
+                    .map(|field| {
+                        Field::new(field.name(), field.data_type().clone(), field.is_nullable())
+                    })
+                    .collect_vec();
                 let mut file_scan_config =
-                    FileScanConfig::new(object_store_url, Arc::clone(&data_schema_arrow))
+                    FileScanConfig::new(object_store_url, Arc::clone(&data_schema))
                         .with_file_groups(file_groups)
                         .with_table_partition_cols(partition_fields);

-                // Check for projection, if so generate the vector and add to FileScanConfig.
-                let mut projection_vector: Vec<usize> =
-                    Vec::with_capacity(required_schema_arrow.fields.len());
-                // TODO: could be faster with a hashmap rather than iterating over data_schema_arrow with index_of.
-                required_schema_arrow.fields.iter().for_each(|field| {
-                    projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap());
-                });
-
-                partition_schema_arrow
-                    .iter()
-                    .enumerate()
-                    .for_each(|(idx, _)| {
-                        projection_vector.push(idx + data_schema_arrow.fields.len());
-                    });
-
                 assert_eq!(
                     projection_vector.len(),
-                    required_schema_arrow.fields.len() + partition_schema_arrow.len()
+                    required_schema.fields.len() + partition_schema.fields.len()
                 );

                 file_scan_config = file_scan_config.with_projection(Some(projection_vector));
@@ -1095,13 +1064,11 @@
                 table_parquet_options.global.pushdown_filters = true;
                 table_parquet_options.global.reorder_filters = true;

-                let mut builder = ParquetExecBuilder::new(file_scan_config)
-                    .with_table_parquet_options(table_parquet_options)
-                    .with_schema_adapter_factory(
-                        Arc::new(CometSchemaAdapterFactory::default()),
-                    );
+                let mut builder = ParquetExecBuilder::new(file_scan_config)
+                    .with_table_parquet_options(table_parquet_options)
+                    .with_schema_adapter_factory(Arc::new(CometSchemaAdapterFactory::default()));

-                if let Some(filter) = test_data_filters {
+                if let Some(filter) = cnf_data_filters {
                     builder = builder.with_predicate(filter);
                 }
@@ -2309,6 +2276,23 @@ fn from_protobuf_eval_mode(value: i32) -> Result<EvalMode, ExecutionError> {
     }
 }

+fn convert_spark_types_to_arrow_schema(
+    spark_types: &[spark_operator::SparkStructField],
+) -> SchemaRef {
+    let arrow_fields = spark_types
+        .iter()
+        .map(|spark_type| {
+            Field::new(
+                String::clone(&spark_type.name),
+                to_arrow_datatype(spark_type.data_type.as_ref().unwrap()),
+                spark_type.nullable,
+            )
+        })
+        .collect_vec();
+    let arrow_schema: SchemaRef = Arc::new(Schema::new(arrow_fields));
+    arrow_schema
+}
+
 #[cfg(test)]
 mod tests {
     use std::{sync::Arc, task::Poll};
diff --git a/native/core/src/execution/datafusion/schema_adapter.rs b/native/core/src/execution/datafusion/schema_adapter.rs
index 16d4b9d67..79dcd5c17 100644
--- a/native/core/src/execution/datafusion/schema_adapter.rs
+++ b/native/core/src/execution/datafusion/schema_adapter.rs
@@ -259,7 +259,8 @@ impl SchemaMapper for SchemaMapping {
                         EvalMode::Legacy,
                         "UTC",
                         false,
-                    )?.into_array(batch_col.len())
+                    )?
+                    .into_array(batch_col.len())
                     // and if that works, return the field and column.
                     .map(|new_col| (new_col, table_field.clone()))
             })
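convert_spark_types_to_arrow_schema above replaces the old Parquet message-type round trip entirely: each SparkStructField carries a name, type, and nullability, which map one-to-one onto Arrow fields. A self-contained sketch of the same conversion, with plain (name, type, nullable) tuples standing in for the prost-generated messages and the arrow_schema crate assumed as a dependency:

    use std::sync::Arc;

    use arrow_schema::{DataType, Field, Schema, SchemaRef};

    // Build an Arrow schema from (name, type, nullable) triples, the same shape
    // of information SparkStructField carries over the wire.
    fn fields_to_schema(fields: &[(&str, DataType, bool)]) -> SchemaRef {
        let arrow_fields: Vec<Field> = fields
            .iter()
            .map(|(name, data_type, nullable)| Field::new(*name, data_type.clone(), *nullable))
            .collect();
        Arc::new(Schema::new(arrow_fields))
    }

    fn main() {
        let schema = fields_to_schema(&[
            ("id", DataType::Int64, false),
            ("name", DataType::Utf8, true),
        ]);
        assert_eq!(schema.fields().len(), 2);
        println!("{schema:?}");
    }

Because names and nullability now survive serialization, the native side no longer needs to synthesize placeholder names like part_0 for partition columns.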
diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto
index 5e8a80f99..b4e12d123 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -61,6 +61,12 @@ message SparkFilePartition {
   repeated SparkPartitionedFile partitioned_file = 1;
 }

+message SparkStructField {
+  string name = 1;
+  spark.spark_expression.DataType data_type = 2;
+  bool nullable = 3;
+}
+
 message Scan {
   repeated spark.spark_expression.DataType fields = 1;
   // The source of the scan (e.g. file scan, broadcast exchange, shuffle, etc). This
@@ -75,11 +81,12 @@ message NativeScan {
   // is purely for informational purposes when viewing native query plans in
   // debug mode.
   string source = 2;
-  string required_schema = 3;
-  string data_schema = 4;
-  repeated spark.spark_expression.DataType partition_schema = 5;
+  repeated SparkStructField required_schema = 3;
+  repeated SparkStructField data_schema = 4;
+  repeated SparkStructField partition_schema = 5;
   repeated spark.spark_expression.Expr data_filters = 6;
   repeated SparkFilePartition file_partitions = 7;
+  repeated int64 projection_vector = 8;
 }

 message Projection {
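The new projection_vector field is what lets the JVM own the projection logic: it sends explicit column offsets, where offsets below the data-schema length select file columns and the remaining offsets select partition columns appended after them. A hypothetical sketch of how such a vector selects columns (the column names are made up for illustration):

    // Resolve a projection vector against a combined column space in which
    // partition columns follow the file's data columns.
    fn project(data_cols: &[&str], partition_cols: &[&str], projection: &[usize]) -> Vec<String> {
        let all: Vec<&str> = data_cols.iter().chain(partition_cols).copied().collect();
        projection.iter().map(|&i| all[i].to_string()).collect()
    }

    fn main() {
        let data = ["a", "b", "c"]; // file (data) schema
        let parts = ["dt"];         // partition schema
        // Required columns c and a, plus the partition column at offset 3.
        assert_eq!(project(&data, &parts, &[2, 0, 3]), ["c", "a", "dt"]);
    }

This matches the assert in planner.rs: the vector's length must equal the required-schema field count plus the partition-schema field count.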
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 7473e9326..29e50d734 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -36,7 +36,6 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD}
-import org.apache.spark.sql.execution.datasources.parquet.SparkToParquetSchemaConverter
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
@@ -2520,18 +2519,28 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           case _ =>
         }

-        val requiredSchemaParquet =
-          new SparkToParquetSchemaConverter(conf).convert(scan.requiredSchema)
-        val dataSchemaParquet =
-          new SparkToParquetSchemaConverter(conf).convert(scan.relation.dataSchema)
-        val partitionSchema = scan.relation.partitionSchema.fields.flatMap { field =>
-          serializeDataType(field.dataType)
-        }
+        val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields)
+        val requiredSchema = schema2Proto(scan.requiredSchema.fields)
+        val dataSchema = schema2Proto(scan.relation.dataSchema.fields)
+
+        val data_schema_idxs = scan.requiredSchema.fields.map(field => {
+          scan.relation.dataSchema.fieldIndex(field.name)
+        })
+        val partition_schema_idxs = Array
+          .range(
+            scan.relation.dataSchema.fields.length,
+            scan.relation.dataSchema.length + scan.relation.partitionSchema.fields.length)
+
+        val projection_vector = (data_schema_idxs ++ partition_schema_idxs).map(idx =>
+          idx.toLong.asInstanceOf[java.lang.Long])
+
+        nativeScanBuilder.addAllProjectionVector(projection_vector.toIterable.asJava)
+
+        // In `CometScanRule`, we ensure partitionSchema is supported.
         assert(partitionSchema.length == scan.relation.partitionSchema.fields.length)

-        nativeScanBuilder.setRequiredSchema(requiredSchemaParquet.toString)
-        nativeScanBuilder.setDataSchema(dataSchemaParquet.toString)
+        nativeScanBuilder.addAllDataSchema(dataSchema.toIterable.asJava)
+        nativeScanBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava)
         nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava)

         Some(result.setNativeScan(nativeScanBuilder).build())
@@ -3198,6 +3207,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
     true
   }

+  private def schema2Proto(
+      fields: Array[StructField]): Array[OperatorOuterClass.SparkStructField] = {
+    val fieldBuilder = OperatorOuterClass.SparkStructField.newBuilder()
+    fields.map(field => {
+      fieldBuilder.setName(field.name)
+      fieldBuilder.setDataType(serializeDataType(field.dataType).get)
+      fieldBuilder.setNullable(field.nullable)
+      fieldBuilder.build()
+    })
+  }
+
   private def partition2Proto(
       partition: FilePartition,
       nativeScanBuilder: OperatorOuterClass.NativeScan.Builder,