feat: [comet-parquet-exec] Schema adapter fixes #1139

Merged
merged 9 commits into from
Dec 6, 2024
Changes from 3 commits
8 changes: 3 additions & 5 deletions native/core/src/execution/datafusion/planner.rs
@@ -1095,11 +1095,9 @@ impl PhysicalPlanner {
table_parquet_options.global.pushdown_filters = true;
table_parquet_options.global.reorder_filters = true;

let mut builder = ParquetExecBuilder::new(file_scan_config)
.with_table_parquet_options(table_parquet_options)
.with_schema_adapter_factory(
Arc::new(CometSchemaAdapterFactory::default()),
);
let mut builder = ParquetExecBuilder::new(file_scan_config)
.with_table_parquet_options(table_parquet_options)
.with_schema_adapter_factory(Arc::new(CometSchemaAdapterFactory::default()));

if let Some(filter) = test_data_filters {
builder = builder.with_predicate(filter);
69 changes: 36 additions & 33 deletions native/core/src/execution/datafusion/schema_adapter.rs
@@ -19,7 +19,7 @@

use arrow::compute::can_cast_types;
use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchOptions};
use arrow_schema::{DataType, Schema, SchemaRef};
use arrow_schema::{DataType, Schema, SchemaRef, TimeUnit};
use datafusion::datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper};
use datafusion_comet_spark_expr::{spark_cast, EvalMode};
use datafusion_common::plan_err;
@@ -38,11 +38,11 @@ impl SchemaAdapterFactory for CometSchemaAdapterFactory {
/// schema.
fn create(
&self,
projected_table_schema: SchemaRef,
required_schema: SchemaRef,
table_schema: SchemaRef,
) -> Box<dyn SchemaAdapter> {
Box::new(CometSchemaAdapter {
projected_table_schema,
required_schema,
table_schema,
})
}
@@ -54,7 +54,7 @@ impl SchemaAdapterFactory for CometSchemaAdapterFactory {
pub struct CometSchemaAdapter {
/// The schema for the table, projected to include only the fields being output (projected) by the
/// associated ParquetExec
projected_table_schema: SchemaRef,
required_schema: SchemaRef,
/// The entire table schema for the table we're using this to adapt.
///
/// This is used to evaluate any filters pushed down into the scan
@@ -69,7 +69,7 @@ impl SchemaAdapter for CometSchemaAdapter {
///
/// Panics if index is not in range for the table schema
fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize> {
let field = self.projected_table_schema.field(index);
let field = self.required_schema.field(index);
Some(file_schema.fields.find(field.name())?.0)
}

@@ -87,40 +87,29 @@
file_schema: &Schema,
) -> datafusion_common::Result<(Arc<dyn SchemaMapper>, Vec<usize>)> {
let mut projection = Vec::with_capacity(file_schema.fields().len());
let mut field_mappings = vec![None; self.projected_table_schema.fields().len()];
let mut field_mappings = vec![None; self.required_schema.fields().len()];

for (file_idx, file_field) in file_schema.fields.iter().enumerate() {
if let Some((table_idx, table_field)) =
self.projected_table_schema.fields().find(file_field.name())
self.required_schema.fields().find(file_field.name())
{
// workaround for struct casting
match (file_field.data_type(), table_field.data_type()) {
// TODO need to use Comet cast logic to determine which casts are supported,
// but for now just add a hack to support casting between struct types
(DataType::Struct(_), DataType::Struct(_)) => {
field_mappings[table_idx] = Some(projection.len());
projection.push(file_idx);
}
_ => {
if can_cast_types(file_field.data_type(), table_field.data_type()) {
field_mappings[table_idx] = Some(projection.len());
projection.push(file_idx);
} else {
return plan_err!(
"Cannot cast file schema field {} of type {:?} to table schema field of type {:?}",
file_field.name(),
file_field.data_type(),
table_field.data_type()
);
}
}
if spark_can_cast_types(file_field.data_type(), table_field.data_type()) {
field_mappings[table_idx] = Some(projection.len());
projection.push(file_idx);
} else {
return plan_err!(
"Cannot cast file schema field {} of type {:?} to required schema field of type {:?}",
file_field.name(),
file_field.data_type(),
table_field.data_type()
);
}
}
}

Ok((
Arc::new(SchemaMapping {
projected_table_schema: self.projected_table_schema.clone(),
required_schema: self.required_schema.clone(),
field_mappings,
table_schema: self.table_schema.clone(),
}),
@@ -129,6 +118,19 @@ impl SchemaAdapter for CometSchemaAdapter {
}
}

pub fn spark_can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
// TODO add all Spark cast rules (they are currently implemented in
// org.apache.comet.expressions.CometCast#isSupported in JVM side)
match (from_type, to_type) {
(DataType::Struct(_), DataType::Struct(_)) => {
// workaround for struct casting
true
}
(_, DataType::Timestamp(TimeUnit::Nanosecond, _)) => false,
_ => can_cast_types(from_type, to_type),
}
}
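
For illustration, a minimal sketch of how the new spark_can_cast_types predicate behaves across its three branches. The demo function below is hypothetical and not part of this change; it assumes only arrow_schema's Field and Fields constructors:

use arrow_schema::{DataType, Field, Fields, TimeUnit};

// Hypothetical demo of the three branches in spark_can_cast_types above.
fn demo_spark_can_cast_types() {
    // Struct-to-struct is allowed unconditionally by the workaround branch,
    // even when the nested field types differ.
    let from = DataType::Struct(Fields::from(vec![Field::new("a", DataType::Int32, true)]));
    let to = DataType::Struct(Fields::from(vec![Field::new("a", DataType::Int64, true)]));
    assert!(spark_can_cast_types(&from, &to));

    // Casts to nanosecond timestamps are rejected (Spark timestamps are microsecond precision).
    assert!(!spark_can_cast_types(
        &DataType::Int64,
        &DataType::Timestamp(TimeUnit::Nanosecond, None)
    ));

    // Everything else falls back to Arrow's can_cast_types.
    assert!(spark_can_cast_types(&DataType::Int32, &DataType::Int64));
}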

// TODO SchemaMapping is mostly copied from DataFusion but calls spark_cast
// instead of arrow cast - can we reduce the amount of code copied here and make
// the DataFusion version more extensible?
@@ -161,7 +163,7 @@ impl SchemaAdapter for CometSchemaAdapter {
pub struct SchemaMapping {
/// The schema of the table. This is the expected schema after conversion
/// and it should match the schema of the query result.
projected_table_schema: SchemaRef,
required_schema: SchemaRef,
Contributor: +1 :)

/// Mapping from field index in `projected_table_schema` to index in
/// projected file_schema.
///
@@ -185,7 +187,7 @@ impl SchemaMapper for SchemaMapping {
let batch_cols = batch.columns().to_vec();

let cols = self
.projected_table_schema
.required_schema
// go through each field in the projected schema
.fields()
.iter()
@@ -218,7 +220,7 @@ impl SchemaMapper for SchemaMapping {
// Necessary to handle empty batches
let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));

let schema = self.projected_table_schema.clone();
let schema = self.required_schema.clone();
let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?;
Ok(record_batch)
}
@@ -259,7 +261,8 @@ impl SchemaMapper for SchemaMapping {
EvalMode::Legacy,
"UTC",
false,
)?.into_array(batch_col.len())
)?
.into_array(batch_col.len())
// and if that works, return the field and column.
.map(|new_col| (new_col, table_field.clone()))
})
41 changes: 38 additions & 3 deletions native/spark-expr/src/utils.rs
@@ -19,14 +19,15 @@ use arrow_array::{
cast::as_primitive_array,
types::{Int32Type, TimestampMicrosecondType},
};
use arrow_schema::{ArrowError, DataType};
use arrow_schema::{ArrowError, DataType, TimeUnit};
use std::sync::Arc;

use crate::timezone::Tz;
use arrow::{
array::{as_dictionary_array, Array, ArrayRef, PrimitiveArray},
temporal_conversions::as_datetime,
};
use arrow_array::types::TimestampMillisecondType;
use chrono::{DateTime, Offset, TimeZone};

/// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or
@@ -70,6 +71,9 @@ pub fn array_with_timezone(
Some(DataType::Timestamp(_, Some(_))) => {
timestamp_ntz_to_timestamp(array, timezone.as_str(), Some(timezone.as_str()))
}
Some(DataType::Timestamp(_, None)) => {
timestamp_ntz_to_timestamp(array, timezone.as_str(), None)
}
_ => {
// Not supported
panic!(
@@ -80,7 +84,7 @@
}
}
}
DataType::Timestamp(_, Some(_)) => {
DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => {
assert!(!timezone.is_empty());
let array = as_primitive_array::<TimestampMicrosecondType>(&array);
let array_with_timezone = array.clone().with_timezone(timezone.clone());
@@ -92,6 +96,18 @@
_ => Ok(array),
}
}
DataType::Timestamp(TimeUnit::Millisecond, Some(_)) => {
assert!(!timezone.is_empty());
let array = as_primitive_array::<TimestampMillisecondType>(&array);
let array_with_timezone = array.clone().with_timezone(timezone.clone());
let array = Arc::new(array_with_timezone) as ArrayRef;
match to_type {
Some(DataType::Utf8) | Some(DataType::Date32) => {
pre_timestamp_cast(array, timezone)
}
_ => Ok(array),
}
}
DataType::Dictionary(_, value_type)
if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
{
@@ -127,7 +143,7 @@
) -> Result<ArrayRef, ArrowError> {
assert!(!tz.is_empty());
match array.data_type() {
DataType::Timestamp(_, None) => {
DataType::Timestamp(TimeUnit::Microsecond, None) => {
let array = as_primitive_array::<TimestampMicrosecondType>(&array);
let tz: Tz = tz.parse()?;
let array: PrimitiveArray<TimestampMicrosecondType> = array.try_unary(|value| {
@@ -146,6 +162,25 @@
};
Ok(Arc::new(array_with_tz))
}
DataType::Timestamp(TimeUnit::Millisecond, None) => {
let array = as_primitive_array::<TimestampMillisecondType>(&array);
let tz: Tz = tz.parse()?;
Contributor: Is this called frequently (per row)? timezone parse is somewhat expensive (and does not change for a session).

Member (author): This is once per array, but I think the parsing could happen once during planning rather than per batch/array.

Contributor: Makes sense, we can defer this for the moment.

let array: PrimitiveArray<TimestampMillisecondType> = array.try_unary(|value| {
as_datetime::<TimestampMillisecondType>(value)
.ok_or_else(|| datetime_cast_err(value))
.map(|local_datetime| {
let datetime: DateTime<Tz> =
tz.from_local_datetime(&local_datetime).unwrap();
datetime.timestamp_millis()
})
})?;
let array_with_tz = if let Some(to_tz) = to_timezone {
array.with_timezone(to_tz)
} else {
array
};
Ok(Arc::new(array_with_tz))
}
_ => Ok(array),
}
}
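
The review thread above notes that the timezone string could be parsed once during planning rather than per array. A minimal sketch of that deferred optimization follows, under the assumption that a pre-parsed Tz can be threaded through to the conversion; the helper name and wiring are hypothetical and not part of this PR, and the optional to_timezone re-tagging done by the function above is dropped for brevity:

use std::sync::Arc;

use arrow::{
    array::{ArrayRef, PrimitiveArray},
    temporal_conversions::as_datetime,
};
use arrow_array::{cast::as_primitive_array, types::TimestampMillisecondType};
use arrow_schema::ArrowError;
use chrono::TimeZone;

use crate::timezone::Tz;

// Hypothetical variant of the millisecond branch above: the session timezone is
// parsed once (for example during planning) and the resulting `Tz` is reused for
// every array, instead of calling `tz.parse()` per batch inside
// timestamp_ntz_to_timestamp.
fn millis_ntz_to_timestamp_with_parsed_tz(
    array: ArrayRef,
    tz: &Tz,
) -> Result<ArrayRef, ArrowError> {
    let array = as_primitive_array::<TimestampMillisecondType>(&array);
    let converted: PrimitiveArray<TimestampMillisecondType> = array.try_unary(|value| {
        as_datetime::<TimestampMillisecondType>(value)
            .ok_or_else(|| ArrowError::CastError(format!("Cannot convert {} to datetime", value)))
            .map(|local_datetime| {
                tz.from_local_datetime(&local_datetime)
                    .unwrap()
                    .timestamp_millis()
            })
    })?;
    Ok(Arc::new(converted))
}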