fix: [comet-parquet-exec] fix regressions in original comet native scan implementation #1170

Merged · 4 commits · Dec 13, 2024
@@ -272,7 +272,7 @@ public void init() throws URISyntaxException, IOException {
     requestedSchema =
         CometParquetReadSupport.clipParquetSchema(
             requestedSchema, sparkSchema, isCaseSensitive, useFieldId, ignoreMissingIds);
-    if (requestedSchema.getColumns().size() != sparkSchema.size()) {
+    if (requestedSchema.getFieldCount() != sparkSchema.size()) {
      throw new IllegalArgumentException(
          String.format(
              "Spark schema has %d columns while " + "Parquet schema has %d columns",
@@ -245,7 +245,7 @@ public void init() throws URISyntaxException, IOException {
     requestedSchema =
         CometParquetReadSupport.clipParquetSchema(
             requestedSchema, sparkSchema, isCaseSensitive, useFieldId, ignoreMissingIds);
-    if (requestedSchema.getColumns().size() != sparkSchema.size()) {
+    if (requestedSchema.getFieldCount() != sparkSchema.size()) {
      throw new IllegalArgumentException(
          String.format(
              "Spark schema has %d columns while " + "Parquet schema has %d columns",
@@ -267,9 +267,9 @@ public void init() throws URISyntaxException, IOException {
     // ShimFileFormat.findRowIndexColumnIndexInSchema(sparkSchema);
     for (int i = 0; i < requestedSchema.getFieldCount(); i++) {
       Type t = requestedSchema.getFields().get(i);
-      Preconditions.checkState(
-          t.isPrimitive() && !t.isRepetition(Type.Repetition.REPEATED),
-          "Complex type is not supported");
+      // Preconditions.checkState(
+      //     t.isPrimitive() && !t.isRepetition(Type.Repetition.REPEATED),
+      //     "Complex type is not supported");
      String[] colPath = paths.get(i);
      if (nonPartitionFields[i].name().equals(ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME())) {
        // Values of ROW_INDEX_TEMPORARY_COLUMN_NAME column are always populated with

Review comment (Contributor Author), on the commented-out Preconditions check: This is for the second implementation. Will remove this later.
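Both readers make the same change, and it matters once complex types are allowed through: MessageType.getColumns() flattens the requested Parquet schema down to its leaf columns, while getFieldCount() counts top-level fields, which is what sparkSchema.size() measures. A minimal Scala sketch of the difference using the standard parquet-java schema parser (the schema below is illustrative, not taken from the PR):

import org.apache.parquet.schema.MessageTypeParser

// Hypothetical two-column schema whose second column is a struct.
val schema = MessageTypeParser.parseMessageType(
  """message spark_schema {
    |  required int32 id;
    |  required group address {
    |    required binary city (UTF8);
    |    required int32 zip;
    |  }
    |}""".stripMargin)

// Two top-level fields, matching what sparkSchema.size() would report.
assert(schema.getFieldCount == 2)
// Three leaf columns, so the old check
// (getColumns().size() != sparkSchema.size()) would throw for struct columns.
assert(schema.getColumns.size == 3)

For a flat schema the two counts agree, which is why the old check only started failing once struct columns were let through.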
native/spark-expr/src/cast.rs (8 additions, 8 deletions)

@@ -571,7 +571,7 @@ impl SparkCastOptions {
            eval_mode,
            timezone: timezone.to_string(),
            allow_incompat,
-            is_adapting_schema: false,
+            is_adapting_schema: false
        }
    }

@@ -583,6 +583,7 @@ impl SparkCastOptions {
            is_adapting_schema: false,
        }
    }
+
}

/// Spark-compatible cast implementation. Defers to DataFusion's cast where that is known
@@ -2087,7 +2088,7 @@ mod tests {

        let timezone = "UTC".to_string();
        // test casting string dictionary array to timestamp array
-        let cast_options = SparkCastOptions::new(EvalMode::Legacy, timezone.clone(), false);
+        let cast_options = SparkCastOptions::new(EvalMode::Legacy, &timezone, false);
        let result = cast_array(
            dict_array,
            &DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())),
@@ -2296,7 +2297,7 @@ mod tests {
    fn test_cast_unsupported_timestamp_to_date() {
        // Since datafusion uses chrono::Datetime internally not all dates representable by TimestampMicrosecondType are supported
        let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
-        let cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC".to_string(), false);
+        let cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
        let result = cast_array(
            Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
            &DataType::Date32,
@@ -2309,7 +2310,7 @@ mod tests {
    fn test_cast_invalid_timezone() {
        let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
        let cast_options =
-            SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone".to_string(), false);
+            SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone", false);
        let result = cast_array(
            Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
            &DataType::Date32,
@@ -2335,7 +2336,7 @@ mod tests {
        let string_array = cast_array(
            c,
            &DataType::Utf8,
-            &SparkCastOptions::new(EvalMode::Legacy, "UTC".to_owned(), false),
+            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
        )
        .unwrap();
        let string_array = string_array.as_string::<i32>();
@@ -2400,10 +2401,9 @@ mod tests {
        let cast_array = spark_cast(
            ColumnarValue::Array(c),
            &DataType::Struct(fields),
-            EvalMode::Legacy,
-            "UTC",
-            false,
-            false,
+            &SparkCastOptions::new(EvalMode::Legacy,
+                "UTC",
+                false)
        )
        .unwrap();
        if let ColumnarValue::Array(cast_array) = cast_array {
spark/src/main/scala/org/apache/comet/DataTypeSupport.scala (4 additions, 1 deletion)

@@ -39,7 +39,10 @@ trait DataTypeSupport {
          BinaryType | StringType | _: DecimalType | DateType | TimestampType =>
        true
      case t: DataType if t.typeName == "timestamp_ntz" => true
-      case _: StructType => true
+      case _: StructType
+          if CometConf.COMET_FULL_NATIVE_SCAN_ENABLED
+            .get() || CometConf.COMET_NATIVE_ARROW_SCAN_ENABLED.get() =>
+        true
      case _ => false
    }

Review comment (Contributor Author), on the new guard: With only the original scan enabled, this caused CometScan to be used for struct types (which it did not support).
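A hedged sketch of how this guard could be exercised from a test, mirroring the withSQLConf pattern this PR adds to CometExpressionSuite below; the parquet path and column name are placeholders, and checkSparkAnswer is assumed to be the usual CometTestBase helper:

// Illustrative only: with the full native scan enabled, struct columns report
// as supported and stay on CometScan; with only the original scan enabled,
// isTypeSupported now returns false and the query falls back to Spark.
withSQLConf(
  CometConf.COMET_ENABLED.key -> "true",
  CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key -> "true") {
  val df = spark.read.parquet("/tmp/struct_data.parquet") // placeholder path
  checkSparkAnswer(df.select("struct_col"))               // placeholder column name
}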
@@ -131,8 +131,15 @@ case class CometScanExec(
  // exposed for testing
  lazy val bucketedScan: Boolean = wrapped.bucketedScan

-  override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) =
-    (wrapped.outputPartitioning, wrapped.outputOrdering)
+  override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = {
+    if (bucketedScan) {
+      (wrapped.outputPartitioning, wrapped.outputOrdering)
+    } else {
+      val files = selectedPartitions.flatMap(partition => partition.files)
+      val numPartitions = files.length
+      (UnknownPartitioning(numPartitions), wrapped.outputOrdering)
+    }
+  }

  @transient
  private lazy val pushedDownFilters = getPushedDownFilters(relation, dataFilters)

Review comment (@parthchandra, Contributor Author, Dec 13, 2024): @viirya - The previous fix to address outputPartitioning (using inputRDD) was not correct and caused multiple test failures. This is a different attempt (and at least all the tests pass). If this is a bucketed scan, we fall back to the wrapped FileSourceScanLike implementation, but in the non-bucketed case, since FileSourceScanLike always returned 0 partitions, we override the behaviour and set the number of partitions to the number of files. I'm not entirely sure this covers all cases, so please advise.
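A note on the choice of UnknownPartitioning in the non-bucketed branch above: it only reports a partition count and promises no particular data distribution, so downstream operators cannot assume a co-partitioning guarantee the scan does not actually provide. A minimal sketch against Spark's Catalyst API (the count of 8 is an arbitrary stand-in for files.length):

import org.apache.spark.sql.catalyst.plans.physical.{UnknownPartitioning, UnspecifiedDistribution}

// UnknownPartitioning carries only the number of partitions.
val partitioning = UnknownPartitioning(8) // stand-in for files.length
assert(partitioning.numPartitions == 8)
// It satisfies only an unspecified distribution requirement, i.e. it makes
// no guarantee such as hash- or range-partitioning on any columns.
assert(partitioning.satisfies(UnspecifiedDistribution))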
@@ -2217,6 +2217,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
      withSQLConf(
        SQLConf.USE_V1_SOURCE_LIST.key -> v1List,
        CometConf.COMET_ENABLED.key -> "true",
+        CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key -> "true",
        CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") {

        val df = spark.read.parquet(dir.toString())
@@ -2249,6 +2250,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
      withSQLConf(
        SQLConf.USE_V1_SOURCE_LIST.key -> v1List,
        CometConf.COMET_ENABLED.key -> "true",
+        CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key -> "true",
        CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") {

        val df = spark.read.parquet(dir.toString())