From 801176e8b92389a8415f82932d842fcd2996ac62 Mon Sep 17 00:00:00 2001 From: Parth Chandra Date: Tue, 9 Apr 2024 16:22:53 -0700 Subject: [PATCH] Add extended explain info to Comet plan --- dev/ensure-jars-have-correct-contents.sh | 2 + pom.xml | 3 + .../CometTPCDSQueriesList-results.txt | 622 ++++++ .../CometTPCHQueriesList-results.txt | 121 ++ .../org/apache/comet/CometExplainInfo.scala | 81 + .../comet/CometSparkSessionExtensions.scala | 314 +++- .../apache/comet/ExtendedExplainInfo.scala | 85 + .../apache/comet/serde/QueryPlanSerde.scala | 1670 ++++++++++------- .../ShimCometSparkSessionExtensions.scala | 15 +- .../spark/sql/ExtendedExplainGenerator.scala | 32 + .../spark/sql/comet/CometExecUtils.scala | 4 +- .../CometTakeOrderedAndProjectExec.scala | 11 +- .../shuffle/CometShuffleExchangeExec.scala | 5 +- .../apache/comet/CometExpressionSuite.scala | 51 + .../apache/spark/sql/CometTPCQueryBase.scala | 3 + .../spark/sql/CometTPCQueryListBase.scala | 20 +- .../org/apache/spark/sql/CometTestBase.scala | 32 +- 17 files changed, 2341 insertions(+), 730 deletions(-) create mode 100644 spark/inspections/CometTPCDSQueriesList-results.txt create mode 100644 spark/inspections/CometTPCHQueriesList-results.txt create mode 100644 spark/src/main/scala/org/apache/comet/CometExplainInfo.scala create mode 100644 spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala create mode 100644 spark/src/main/scala/org/apache/spark/sql/ExtendedExplainGenerator.scala diff --git a/dev/ensure-jars-have-correct-contents.sh b/dev/ensure-jars-have-correct-contents.sh index 5543093ff6..1f97d2d4a7 100755 --- a/dev/ensure-jars-have-correct-contents.sh +++ b/dev/ensure-jars-have-correct-contents.sh @@ -78,6 +78,8 @@ allowed_expr+="|^org/apache/spark/shuffle/sort/CometShuffleExternalSorter.*$" allowed_expr+="|^org/apache/spark/shuffle/sort/RowPartition.class$" allowed_expr+="|^org/apache/spark/shuffle/comet/.*$" allowed_expr+="|^org/apache/spark/sql/$" +# allow ExplainPlanGenerator trait since it may not be available in older Spark versions +allowed_expr+="|^org/apache/spark/sql/ExtendedExplainGenerator.*$" allowed_expr+="|^org/apache/spark/CometPlugin.class$" allowed_expr+="|^org/apache/spark/CometDriverPlugin.*$" allowed_expr+="|^org/apache/spark/CometTaskMemoryManager.class$" diff --git a/pom.xml b/pom.xml index c7e417dd1b..2b04c5e3e6 100644 --- a/pom.xml +++ b/pom.xml @@ -927,6 +927,9 @@ under the License. javax.annotation.meta.TypeQualifierValidator org.apache.parquet.filter2.predicate.SparkFilterApi + + + org.apache.spark.sql.ExtendedExplainGenerator diff --git a/spark/inspections/CometTPCDSQueriesList-results.txt b/spark/inspections/CometTPCDSQueriesList-results.txt new file mode 100644 index 0000000000..9ca4eda36c --- /dev/null +++ b/spark/inspections/CometTPCDSQueriesList-results.txt @@ -0,0 +1,622 @@ +Query: q1. Comet Exec: Enabled (CometProject, CometFilter) +Query: q1: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q2. Comet Exec: Enabled (CometUnion, CometProject, CometFilter) +Query: q2: ExplainInfo: +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q3. Comet Exec: Enabled (CometProject, CometFilter) +Query: q3: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q4. Comet Exec: Enabled (CometProject, CometFilter) +Query: q4: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q5. 
Comet Exec: Enabled (CometUnion, CometProject, CometFilter) +Query: q5: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q6. Comet Exec: Enabled (CometHashAggregate, CometProject, CometFilter) +Query: q6: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q7: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q8. Comet Exec: Enabled (CometProject, CometFilter) +Query: q8: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q9. Comet Exec: Enabled (CometFilter) +Query: q9: ExplainInfo: +named_struct is not supported +getstructfield is not supported + +Query: q10. Comet Exec: Enabled (CometProject, CometFilter) +Query: q10: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q11. Comet Exec: Enabled (CometProject, CometFilter) +Query: q11: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q12. Comet Exec: Enabled (CometProject, CometFilter) +Query: q12: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported + +Query: q13. Comet Exec: Enabled (CometProject, CometFilter) +Query: q13: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q14a. Comet Exec: Enabled (CometProject, CometFilter) +Query: q14a: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q14b. Comet Exec: Enabled (CometProject, CometFilter) +Query: q14b: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q15. Comet Exec: Enabled (CometProject, CometFilter) +Query: q15: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q16. Comet Exec: Enabled (CometProject, CometFilter) +Query: q16: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q17. Comet Exec: Enabled (CometProject, CometFilter) +Query: q17: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q18. Comet Exec: Enabled (CometProject, CometFilter) +Query: q18: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q19. Comet Exec: Enabled (CometProject, CometFilter) +Query: q19: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q20. Comet Exec: Enabled (CometProject, CometFilter) +Query: q20: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported + +Query: q21. Comet Exec: Enabled (CometProject, CometFilter) +Query: q21: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q22. Comet Exec: Enabled (CometProject, CometFilter) +Query: q22: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q23a. Comet Exec: Enabled (CometProject, CometFilter) +Query: q23a: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q23b. Comet Exec: Enabled (CometProject, CometFilter) +Query: q23b: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q24a. Comet Exec: Enabled (CometProject, CometFilter) +Query: q24a: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q24b. Comet Exec: Enabled (CometProject, CometFilter) +Query: q24b: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q25. Comet Exec: Enabled (CometProject, CometFilter) +Query: q25: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q26. Comet Exec: Enabled (CometProject, CometFilter) +Query: q26: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q27. Comet Exec: Enabled (CometProject, CometFilter) +Query: q27: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q28. 
Comet Exec: Enabled (CometHashAggregate, CometProject, CometFilter) +Query: q28: ExplainInfo: +Unsupported aggregation mode PartialMerge +BroadcastNestedLoopJoin is not supported + +Query: q29. Comet Exec: Enabled (CometProject, CometFilter) +Query: q29: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q30. Comet Exec: Enabled (CometProject, CometFilter) +Query: q30: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q31. Comet Exec: Enabled (CometFilter) +Query: q31: ExplainInfo: +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q32. Comet Exec: Enabled (CometProject, CometFilter) +Query: q32: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q33. Comet Exec: Enabled (CometProject, CometFilter) +Query: q33: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q34. Comet Exec: Enabled (CometProject, CometFilter) +Query: q34: ExplainInfo: +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q35. Comet Exec: Enabled (CometProject, CometFilter) +Query: q35: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q36. Comet Exec: Enabled (CometProject, CometFilter) +Query: q36: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported + +Query: q37. Comet Exec: Enabled (CometProject, CometFilter) +Query: q37: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q38. Comet Exec: Enabled (CometProject, CometFilter) +Query: q38: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q39a. Comet Exec: Enabled (CometProject, CometFilter) +Query: q39a: ExplainInfo: +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q39b. Comet Exec: Enabled (CometProject, CometFilter) +Query: q39b: ExplainInfo: +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q40. Comet Exec: Enabled (CometProject, CometFilter) +Query: q40: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q41. Comet Exec: Enabled (CometHashAggregate, CometProject, CometFilter) +Query: q41: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q42. Comet Exec: Enabled (CometProject, CometFilter) +Query: q42: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q43. Comet Exec: Enabled (CometProject, CometFilter) +Query: q43: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q44. Comet Exec: Enabled (CometHashAggregate, CometSort, CometProject, CometFilter) +Query: q44: ExplainInfo: +Window is not supported +BroadcastHashJoin is not supported + +Query: q45. Comet Exec: Enabled (CometProject, CometFilter) +Query: q45: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q46. Comet Exec: Enabled (CometProject, CometFilter) +Query: q46: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q47. Comet Exec: Enabled (CometProject, CometFilter) +Query: q47: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported + +Query: q48. Comet Exec: Enabled (CometProject, CometFilter) +Query: q48: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q49. 
Comet Exec: Enabled (CometProject, CometFilter) +Query: q49: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported +Not all subqueries for union are native + +Query: q50. Comet Exec: Enabled (CometProject, CometFilter) +Query: q50: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q51. Comet Exec: Enabled (CometProject, CometFilter) +Query: q51: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported +SortMergeJoin is not supported + +Query: q52. Comet Exec: Enabled (CometProject, CometFilter) +Query: q52: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q53. Comet Exec: Enabled (CometProject, CometFilter) +Query: q53: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported + +Query: q54. Comet Exec: Enabled (CometUnion, CometProject, CometFilter) +Query: q54: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q55. Comet Exec: Enabled (CometProject, CometFilter) +Query: q55: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q56. Comet Exec: Enabled (CometProject, CometFilter) +Query: q56: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q57. Comet Exec: Enabled (CometProject, CometFilter) +Query: q57: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported + +Query: q58. Comet Exec: Enabled (CometProject, CometFilter) +Query: q58: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q59. Comet Exec: Enabled (CometProject, CometFilter) +Query: q59: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q60. Comet Exec: Enabled (CometProject, CometFilter) +Query: q60: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q61. Comet Exec: Enabled (CometProject, CometFilter) +Query: q61: ExplainInfo: +BroadcastHashJoin is not supported +BroadcastNestedLoopJoin is not supported + +Query: q62. Comet Exec: Enabled (CometProject, CometFilter) +Query: q62: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q63. Comet Exec: Enabled (CometProject, CometFilter) +Query: q63: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported + +Query: q64. Comet Exec: Enabled (CometProject, CometFilter) +Query: q64: ExplainInfo: +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q65. Comet Exec: Enabled (CometProject, CometFilter) +Query: q65: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q66. Comet Exec: Enabled (CometProject, CometFilter) +Query: q66: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q67. Comet Exec: Enabled (CometProject, CometFilter) +Query: q67: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported + +Query: q68. Comet Exec: Enabled (CometProject, CometFilter) +Query: q68: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q69. Comet Exec: Enabled (CometProject, CometFilter) +Query: q69: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q70. Comet Exec: Enabled (CometProject, CometFilter) +Query: q70: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported + +Query: q71. Comet Exec: Enabled (CometProject, CometFilter) +Query: q71: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q72. 
Comet Exec: Enabled (CometProject, CometFilter) +Query: q72: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q73. Comet Exec: Enabled (CometProject, CometFilter) +Query: q73: ExplainInfo: +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q74. Comet Exec: Enabled (CometProject, CometFilter) +Query: q74: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q75. Comet Exec: Enabled (CometProject, CometFilter) +Query: q75: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q76. Comet Exec: Enabled (CometProject, CometFilter) +Query: q76: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q77. Comet Exec: Enabled (CometProject, CometFilter) +Query: q77: ExplainInfo: +BroadcastHashJoin is not supported +BroadcastNestedLoopJoin is not supported +Not all subqueries for union are native + +Query: q78. Comet Exec: Enabled (CometFilter) +Query: q78: ExplainInfo: +BroadcastHashJoin is not supported +Comet does not support Spark's BigDecimal rounding + +Query: q79. Comet Exec: Enabled (CometProject, CometFilter) +Query: q79: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q80. Comet Exec: Enabled (CometProject, CometFilter) +Query: q80: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q81. Comet Exec: Enabled (CometProject, CometFilter) +Query: q81: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q82. Comet Exec: Enabled (CometProject, CometFilter) +Query: q82: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q83. Comet Exec: Enabled (CometProject, CometFilter) +Query: q83: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q84. Comet Exec: Enabled (CometProject, CometFilter) +Query: q84: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q85. Comet Exec: Enabled (CometProject, CometFilter) +Query: q85: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q86. Comet Exec: Enabled (CometProject, CometFilter) +Query: q86: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported + +Query: q87. Comet Exec: Enabled (CometProject, CometFilter) +Query: q87: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q88. Comet Exec: Enabled (CometProject, CometFilter) +Query: q88: ExplainInfo: +BroadcastHashJoin is not supported +BroadcastNestedLoopJoin is not supported + +Query: q89. Comet Exec: Enabled (CometProject, CometFilter) +Query: q89: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported + +Query: q90. Comet Exec: Enabled (CometProject, CometFilter) +Query: q90: ExplainInfo: +BroadcastHashJoin is not supported +BroadcastNestedLoopJoin is not supported + +Query: q91. Comet Exec: Enabled (CometProject, CometFilter) +Query: q91: ExplainInfo: +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q92. Comet Exec: Enabled (CometProject, CometFilter) +Query: q92: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q93. Comet Exec: Enabled (CometProject, CometFilter) +Query: q93: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q94. Comet Exec: Enabled (CometProject, CometFilter) +Query: q94: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q95. 
Comet Exec: Enabled (CometProject, CometFilter) +Query: q95: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q96. Comet Exec: Enabled (CometProject, CometFilter) +Query: q96: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q97. Comet Exec: Enabled (CometProject, CometFilter) +Query: q97: ExplainInfo: +BroadcastHashJoin is not supported +SortMergeJoin is not supported + +Query: q98. Comet Exec: Enabled (CometProject, CometFilter) +Query: q98: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q99. Comet Exec: Enabled (CometProject, CometFilter) +Query: q99: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q5a-v2.7. Comet Exec: Enabled (CometUnion, CometProject, CometFilter) +Query: q5a-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q6-v2.7. Comet Exec: Enabled (CometHashAggregate, CometProject, CometFilter) +Query: q6-v2.7: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q10a-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q10a-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q11-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q11-v2.7: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q12-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q12-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported + +Query: q14-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q14-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q14a-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q14a-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q18a-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q18a-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q20-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q20-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported + +Query: q22-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q22-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +BroadcastNestedLoopJoin is not supported + +Query: q22a-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q22a-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q24-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q24-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q27a-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q27a-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q34-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q34-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q35-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q35-v2.7: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q35a-v2.7. 
Comet Exec: Enabled (CometProject, CometFilter) +Query: q35a-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q36a-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q36a-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native +Window is not supported + +Query: q47-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q47-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported + +Query: q49-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q49-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported +Not all subqueries for union are native + +Query: q51a-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q51a-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported +SortMergeJoin is not supported + +Query: q57-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q57-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported + +Query: q64-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q64-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q67a-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q67a-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native +Window is not supported + +Query: q70a-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q70a-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported +Not all subqueries for union are native + +Query: q72-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q72-v2.7: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q74-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q74-v2.7: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q75-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q75-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q77a-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q77a-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +BroadcastNestedLoopJoin is not supported +Not all subqueries for union are native + +Query: q78-v2.7. Comet Exec: Enabled (CometFilter) +Query: q78-v2.7: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q80a-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q80a-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native + +Query: q86a-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q86a-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Not all subqueries for union are native +Window is not supported + +Query: q98-v2.7. Comet Exec: Enabled (CometProject, CometFilter) +Query: q98-v2.7: ExplainInfo: +BroadcastHashJoin is not supported +Window is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning diff --git a/spark/inspections/CometTPCHQueriesList-results.txt b/spark/inspections/CometTPCHQueriesList-results.txt new file mode 100644 index 0000000000..c9bc6ce6de --- /dev/null +++ b/spark/inspections/CometTPCHQueriesList-results.txt @@ -0,0 +1,121 @@ +Query: q1 TPCH Snappy. 
Comet Exec: Enabled (CometHashAggregate, CometProject) +Query: q1 TPCH Snappy: ExplainInfo: +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q2 TPCH Snappy. Comet Exec: Enabled (CometSort, CometSortMergeJoin, CometProject, CometFilter) +Query: q2 TPCH Snappy: ExplainInfo: +ObjectHashAggregate is not supported +might_contain is not supported +BroadcastHashJoin is not supported +SortMergeJoin is not supported + +Query: q3 TPCH Snappy. Comet Exec: Enabled (CometHashAggregate, CometSort, CometTakeOrderedAndProject, CometSortMergeJoin, CometProject, CometFilter) +Query: q3 TPCH Snappy: ExplainInfo: + + +Query: q4 TPCH Snappy. Comet Exec: Enabled (CometHashAggregate, CometSort, CometSortMergeJoin, CometProject, CometFilter) +Query: q4 TPCH Snappy: ExplainInfo: +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q5 TPCH Snappy. Comet Exec: Enabled (CometSort, CometSortMergeJoin, CometProject, CometFilter) +Query: q5 TPCH Snappy: ExplainInfo: +ObjectHashAggregate is not supported +might_contain is not supported +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q6 TPCH Snappy. Comet Exec: Enabled (CometHashAggregate, CometProject, CometFilter) +Query: q6 TPCH Snappy: ExplainInfo: + + +Query: q7 TPCH Snappy. Comet Exec: Enabled (CometSort, CometProject, CometFilter) +Query: q7 TPCH Snappy: ExplainInfo: +ObjectHashAggregate is not supported +might_contain is not supported +SortMergeJoin is not supported +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q8 TPCH Snappy. Comet Exec: Enabled (CometSort, CometSortMergeJoin, CometProject, CometFilter) +Query: q8 TPCH Snappy: ExplainInfo: +ObjectHashAggregate is not supported +might_contain is not supported +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q9 TPCH Snappy. Comet Exec: Enabled (CometSort, CometSortMergeJoin, CometProject, CometFilter) +Query: q9 TPCH Snappy: ExplainInfo: +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q10 TPCH Snappy. Comet Exec: Enabled (CometSort, CometSortMergeJoin, CometProject, CometFilter) +Query: q10 TPCH Snappy: ExplainInfo: +BroadcastHashJoin is not supported + +Query: q11 TPCH Snappy. Comet Exec: Enabled (CometSort, CometProject, CometFilter) +Query: q11 TPCH Snappy: ExplainInfo: +ObjectHashAggregate is not supported +might_contain is not supported +SortMergeJoin is not supported +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q12 TPCH Snappy. Comet Exec: Enabled (CometHashAggregate, CometSort, CometSortMergeJoin, CometProject, CometFilter) +Query: q12 TPCH Snappy: ExplainInfo: +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q13 TPCH Snappy. Comet Exec: Enabled (CometHashAggregate, CometSort, CometSortMergeJoin, CometProject, CometFilter) +Query: q13 TPCH Snappy: ExplainInfo: +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q14 TPCH Snappy. 
Comet Exec: Enabled (CometHashAggregate, CometSort, CometSortMergeJoin, CometProject, CometFilter) +Query: q14 TPCH Snappy: ExplainInfo: + + +Query: q15 TPCH Snappy. Comet Exec: Enabled (CometHashAggregate, CometSort, CometSortMergeJoin, CometProject, CometFilter) +Query: q15 TPCH Snappy: ExplainInfo: +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q16 TPCH Snappy. Comet Exec: Enabled (CometSort, CometProject, CometFilter) +Query: q16 TPCH Snappy: ExplainInfo: +BroadcastHashJoin is not supported +SortMergeJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q17 TPCH Snappy. Comet Exec: Enabled (CometHashAggregate, CometSort, CometSortMergeJoin, CometProject, CometFilter) +Query: q17 TPCH Snappy: ExplainInfo: +Sort merge join with a join condition is not supported + +Query: q18 TPCH Snappy. Comet Exec: Enabled (CometHashAggregate, CometSort, CometTakeOrderedAndProject, CometSortMergeJoin, CometProject, CometFilter) +Query: q18 TPCH Snappy: ExplainInfo: + + +Query: q19 TPCH Snappy. Comet Exec: Enabled (CometSort, CometProject, CometFilter) +Query: q19 TPCH Snappy: ExplainInfo: +Sort merge join with a join condition is not supported + +Query: q20 TPCH Snappy. Comet Exec: Enabled (CometHashAggregate, CometSort, CometSortMergeJoin, CometProject, CometFilter) +Query: q20 TPCH Snappy: ExplainInfo: +ObjectHashAggregate is not supported +Sort merge join with a join condition is not supported +might_contain is not supported +SortMergeJoin is not supported +BroadcastHashJoin is not supported +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q21 TPCH Snappy. Comet Exec: Enabled (CometSort, CometProject, CometFilter) +Query: q21 TPCH Snappy: ExplainInfo: +ObjectHashAggregate is not supported +Sort merge join with a join condition is not supported +SortMergeJoin is not supported +might_contain is not supported +BroadcastHashJoin is not supported + +Query: q22 TPCH Snappy. Comet Exec: Enabled (CometHashAggregate, CometSort, CometSortMergeJoin, CometProject, CometFilter) +Query: q22 TPCH Snappy: ExplainInfo: +Shuffle: unsupported Spark partitioning: org.apache.spark.sql.catalyst.plans.physical.RangePartitioning + +Query: q1 TPCH Extended Snappy. Comet Exec: Enabled (CometHashAggregate, CometProject, CometFilter) +Query: q1 TPCH Extended Snappy: ExplainInfo: + + diff --git a/spark/src/main/scala/org/apache/comet/CometExplainInfo.scala b/spark/src/main/scala/org/apache/comet/CometExplainInfo.scala new file mode 100644 index 0000000000..eb36a2d176 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/CometExplainInfo.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.trees.TreeNodeTag + +class CometExplainInfo extends Serializable { + var info: String = "" + var children: Seq[CometExplainInfo] = Seq.empty + + override def toString: String = + if (children.isEmpty) info else s"$info(${children.mkString("", ", ", "")})" + + def toSimpleString: String = { + val sb = mutable.Set[String]() + dedup(sb) + sb.mkString("\n") + } + + private def dedup(all: mutable.Set[String]): mutable.Set[String] = { + if (children.isEmpty) { + all += info + } else { + children + .filter(o => (o != CometExplainInfo.none && o != CometExplainInfo.subTreeIsNotNative)) + .map(c => c.dedup(all)) + // return only the child node. Parent nodes clutter up the displayed info and in practice + // do not have very useful information. + all + } + } +} + +case class CometExplainSubtreeIsNotNative() extends CometExplainInfo + +object CometExplainInfo { + val none: CometExplainInfo = null + val subTreeIsNotNative: CometExplainSubtreeIsNotNative = CometExplainSubtreeIsNotNative() + val EXTENSION_INFO = new TreeNodeTag[String]("CometExtensionInfo") + + def apply(info: String): CometExplainInfo = { + val b = new CometExplainInfo + b.info = info + b + } + + def apply(info: String, child: CometExplainInfo): CometExplainInfo = { + val b = new CometExplainInfo + b.info = info + if (child != null) { + b.children = Seq(child) + } + b + } + + def apply(info: String, children: Seq[CometExplainInfo]): CometExplainInfo = { + val b = new CometExplainInfo + b.info = info + b.children = children + b + } +} diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index a10ac573ef..594b918fa7 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -24,8 +24,8 @@ import java.nio.ByteOrder import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle} @@ -44,7 +44,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.comet.CometConf._ -import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, shouldApplyRowToColumnar} +import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, opWithInfo, shouldApplyRowToColumnar} import org.apache.comet.parquet.{CometParquetScan, SupportsComet} import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.QueryPlanSerde @@ 
-74,8 +74,16 @@ class CometSparkSessionExtensions case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = { - if (!isCometEnabled(conf) || !isCometScanEnabled(conf)) plan - else { + if (!isCometEnabled(conf) || !isCometScanEnabled(conf)) { + val info = if (!isCometEnabled(conf)) { + CometExplainInfo("Comet is not enabled") + } else if (!isCometScanEnabled(conf)) { + CometExplainInfo("Comet Scan is not enabled") + } else { + CometExplainInfo.none + } + opWithInfo(plan, info) + } else { plan.transform { // data source V2 case scanExec: BatchScanExec @@ -90,6 +98,32 @@ class CometSparkSessionExtensions scanExec.copy(scan = cometScan), runtimeFilters = scanExec.runtimeFilters) + // unsupported parquet data source V2 + case scanExec: BatchScanExec if scanExec.scan.isInstanceOf[ParquetScan] => + val requiredSchema = scanExec.scan.asInstanceOf[ParquetScan].readDataSchema + val info1 = if (isSchemaSupported(requiredSchema)) { + CometExplainInfo(s"Schema $requiredSchema is not supported") + } else { + CometExplainInfo.none + } + val readPartitionSchema = scanExec.scan.asInstanceOf[ParquetScan].readPartitionSchema + val info2 = if (isSchemaSupported(readPartitionSchema)) { + CometExplainInfo(s"Schema $readPartitionSchema is not supported") + } else { + CometExplainInfo.none + } + // Comet does not support pushedAggregate + val info3 = + if (!getPushedAggregate(scanExec.scan.asInstanceOf[ParquetScan]).isEmpty) { + CometExplainInfo("Comet does not support pushed aggregate") + } else { + CometExplainInfo.none + } + opWithInfo(scanExec, CometExplainInfo("SCAN", Seq(info1, info2, info3))) + + case scanExec: BatchScanExec if !scanExec.scan.isInstanceOf[ParquetScan] => + opWithInfo(scanExec, CometExplainInfo("Comet Scan only supports Parquet")) + // iceberg scan case scanExec: BatchScanExec => if (isSchemaSupported(scanExec.scan.readSchema())) { @@ -102,22 +136,22 @@ class CometSparkSessionExtensions scanExec.clone().asInstanceOf[BatchScanExec], runtimeFilters = scanExec.runtimeFilters) case _ => - logInfo( - "Comet extension is not enabled for " + - s"${scanExec.scan.getClass.getSimpleName}: not enabled on data source side") - scanExec + val msg = "Comet extension is not enabled for " + + s"${scanExec.scan.getClass.getSimpleName}: not enabled on data source side" + logInfo(msg) + opWithInfo(scanExec, CometExplainInfo(msg)) } } else { - logInfo( - "Comet extension is not enabled for " + - s"${scanExec.scan.getClass.getSimpleName}: Schema not supported") - scanExec + val msg = "Comet extension is not enabled for " + + s"${scanExec.scan.getClass.getSimpleName}: Schema not supported" + logInfo(msg) + opWithInfo(scanExec, CometExplainInfo(msg)) } // data source V1 case scanExec @ FileSourceScanExec( HadoopFsRelation(_, partitionSchema, _, _, _: ParquetFileFormat, _), - _: Seq[_], + _: Seq[AttributeReference], requiredSchema, _, _, @@ -127,6 +161,29 @@ class CometSparkSessionExtensions _) if isSchemaSupported(requiredSchema) && isSchemaSupported(partitionSchema) => logInfo("Comet extension enabled for v1 Scan") CometScanExec(scanExec, session) + + // data source v1 not supported case + case scanExec @ FileSourceScanExec( + HadoopFsRelation(_, partitionSchema, _, _, _: ParquetFileFormat, _), + _: Seq[AttributeReference], + requiredSchema, + _, + _, + _, + _, + _, + _) => + val info1 = if (!isSchemaSupported(requiredSchema)) { + CometExplainInfo(s"Schema $requiredSchema is not supported") + } else { + CometExplainInfo.none + } + val info2 = if 
(!isSchemaSupported(partitionSchema)) { + CometExplainInfo(s"Schema $partitionSchema is not supported") + } else { + CometExplainInfo.none + } + opWithInfo(scanExec, CometExplainInfo("SCAN", Seq(info1, info2))) } } } @@ -137,7 +194,7 @@ class CometSparkSessionExtensions plan.transformUp { case s: ShuffleExchangeExec if isCometPlan(s.child) && !isCometColumnarShuffleEnabled(conf) && - QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning) => + QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._1 => logInfo("Comet extension enabled for Native Shuffle") // Switch to use Decimal128 regardless of precision, since Arrow native execution @@ -145,12 +202,11 @@ class CometSparkSessionExtensions conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true") CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle) - // Columnar shuffle for regular Spark operators (not Comet) and Comet operators - // (if configured) + // Arrow shuffle for regular Spark operators (not Comet) and Comet operators (if configured) case s: ShuffleExchangeExec if (!s.child.supportsColumnar || isCometPlan( s.child)) && isCometColumnarShuffleEnabled(conf) && - QueryPlanSerde.supportPartitioningTypes(s.child.output) => + QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._1 => logInfo("Comet extension enabled for JVM Columnar Shuffle") CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle) } @@ -224,28 +280,37 @@ class CometSparkSessionExtensions */ // spotless:on private def transform(plan: SparkPlan): SparkPlan = { - def transform1(op: SparkPlan): Option[Operator] = { - if (op.children.forall(_.isInstanceOf[CometNativeExec])) { - QueryPlanSerde.operator2Proto( - op, - op.children.map(_.asInstanceOf[CometNativeExec].nativeOp): _*) + def transform1(op: SparkPlan): (Option[Operator], CometExplainInfo) = { + val allNativeExec = op.children.map { + case childNativeOp: CometNativeExec => Some(childNativeOp.nativeOp) + case _ => None + } + + if (allNativeExec.forall(_.isDefined)) { + QueryPlanSerde.operator2Proto(op, allNativeExec.map(_.get): _*) } else { - None + QueryPlanSerde.unsupported(op.nodeName, CometExplainInfo.subTreeIsNotNative) } } plan.transformUp { case op if isCometScan(op) => - val nativeOp = QueryPlanSerde.operator2Proto(op).get - CometScanWrapper(nativeOp, op) + val (nativeOp, info) = QueryPlanSerde.operator2Proto(op) + nativeOp match { + case Some(scanOp) => CometScanWrapper(scanOp, op) + case None => opWithInfo(op, info) + } case op if shouldApplyRowToColumnar(conf, op) => val cometOp = CometRowToColumnarExec(op) - val nativeOp = QueryPlanSerde.operator2Proto(cometOp).get - CometScanWrapper(nativeOp, cometOp) + val (nativeOp, info) = QueryPlanSerde.operator2Proto(cometOp) + nativeOp match { + case Some(scanOp) => CometScanWrapper(scanOp, op) + case None => opWithInfo(op, info) + } case op: ProjectExec => - val newOp = transform1(op) + val (newOp, info) = transform1(op) newOp match { case Some(nativeOp) => CometProjectExec( @@ -256,38 +321,38 @@ class CometSparkSessionExtensions op.child, SerializedPlan(None)) case None => - op + opWithInfo(op, info) } case op: FilterExec => - val newOp = transform1(op) + val (newOp, info) = transform1(op) newOp match { case Some(nativeOp) => CometFilterExec(nativeOp, op, op.condition, op.child, SerializedPlan(None)) case None => - op + opWithInfo(op, info) } case op: SortExec => - val newOp = transform1(op) + val (newOp, info) = transform1(op) newOp match { case Some(nativeOp) => 
CometSortExec(nativeOp, op, op.sortOrder, op.child, SerializedPlan(None)) case None => - op + opWithInfo(op, info) } case op: LocalLimitExec => - val newOp = transform1(op) + val (newOp, info) = transform1(op) newOp match { case Some(nativeOp) => CometLocalLimitExec(nativeOp, op, op.limit, op.child, SerializedPlan(None)) case None => - op + opWithInfo(op, info) } case op: GlobalLimitExec => - val newOp = transform1(op) + val (newOp, info) = transform1(op) newOp match { case Some(nativeOp) => CometGlobalLimitExec(nativeOp, op, op.limit, op.child, SerializedPlan(None)) @@ -299,27 +364,28 @@ class CometSparkSessionExtensions if isCometNative(op.child) && isCometOperatorEnabled(conf, "collectLimit") && isCometShuffleEnabled(conf) && getOffset(op) == 0 => - QueryPlanSerde.operator2Proto(op) match { + val (newOp, info) = QueryPlanSerde.operator2Proto(op) + newOp match { case Some(nativeOp) => val offset = getOffset(op) val cometOp = CometCollectLimitExec(op, op.limit, offset, op.child) CometSinkPlaceHolder(nativeOp, op, cometOp) case None => - op + opWithInfo(op, info) } case op: ExpandExec => - val newOp = transform1(op) + val (newOp, info) = transform1(op) newOp match { case Some(nativeOp) => CometExpandExec(nativeOp, op, op.projections, op.child, SerializedPlan(None)) case None => - op + opWithInfo(op, info) } case op @ HashAggregateExec(_, _, _, groupingExprs, aggExprs, _, _, _, child) => - val newOp = transform1(op) + val (newOp, info) = transform1(op) newOp match { case Some(nativeOp) => val modes = aggExprs.map(_.mode).distinct @@ -338,13 +404,13 @@ class CometSparkSessionExtensions child, SerializedPlan(None)) case None => - op + opWithInfo(op, info) } case op: ShuffledHashJoinExec if isCometOperatorEnabled(conf, "hash_join") && op.children.forall(isCometNative(_)) => - val newOp = transform1(op) + val (newOp, info) = transform1(op) newOp match { case Some(nativeOp) => CometHashJoinExec( @@ -359,13 +425,13 @@ class CometSparkSessionExtensions op.right, SerializedPlan(None)) case None => - op + opWithInfo(op, info) } case op: BroadcastHashJoinExec if isCometOperatorEnabled(conf, "broadcast_hash_join") && op.children.forall(isCometNative(_)) => - val newOp = transform1(op) + val (newOp, info) = transform1(op) newOp match { case Some(nativeOp) => CometBroadcastHashJoinExec( @@ -380,13 +446,13 @@ class CometSparkSessionExtensions op.right, SerializedPlan(None)) case None => - op + opWithInfo(op, info) } case op: SortMergeJoinExec if isCometOperatorEnabled(conf, "sort_merge_join") && op.children.forall(isCometNative(_)) => - val newOp = transform1(op) + val (newOp, info) = transform1(op) newOp match { case Some(nativeOp) => CometSortMergeJoinExec( @@ -400,52 +466,98 @@ class CometSparkSessionExtensions op.right, SerializedPlan(None)) case None => - op + opWithInfo(op, info) } case c @ CoalesceExec(numPartitions, child) if isCometOperatorEnabled(conf, "coalesce") && isCometNative(child) => - QueryPlanSerde.operator2Proto(c) match { + val (newOp, info) = QueryPlanSerde.operator2Proto(c) + newOp match { case Some(nativeOp) => val cometOp = CometCoalesceExec(c, numPartitions, child) CometSinkPlaceHolder(nativeOp, c, cometOp) case None => - c + opWithInfo(c, info) } case s: TakeOrderedAndProjectExec if isCometNative(s.child) && isCometOperatorEnabled(conf, "takeOrderedAndProjectExec") && isCometShuffleEnabled(conf) && - CometTakeOrderedAndProjectExec.isSupported(s) => - QueryPlanSerde.operator2Proto(s) match { + CometTakeOrderedAndProjectExec + .isSupported(s) + ._1 => + // TODO: support offset 
for Spark 3.4 + val (newOp, info) = QueryPlanSerde.operator2Proto(s) + newOp match { case Some(nativeOp) => val cometOp = CometTakeOrderedAndProjectExec(s, s.limit, s.sortOrder, s.projectList, s.child) CometSinkPlaceHolder(nativeOp, s, cometOp) case None => - s + opWithInfo(s, info) } + case s: TakeOrderedAndProjectExec => + val msg1 = + if (!isCometNative(s.child)) { + CometExplainInfo.subTreeIsNotNative + } else { + CometExplainInfo.none + } + val msg2 = + if (!isCometOperatorEnabled(conf, "takeOrderedAndProjectExec")) { + CometExplainInfo("TakeOrderedAndProject is not enabled") + } else { + CometExplainInfo.none + } + val (isTakeOrderedAndProjectSupported, notSupportedReason) = + CometTakeOrderedAndProjectExec.isSupported(s) + val msg3 = + if (!isTakeOrderedAndProjectSupported) { + notSupportedReason + } else { + CometExplainInfo.none + } + val msg4 = + if (!isCometShuffleEnabled(conf)) { + CometExplainInfo("TakeOrderedAndProject requires shuffle to be enabled") + } else { + CometExplainInfo.none + } + opWithInfo(s, CometExplainInfo("TakeOrderedAndProject", Seq(msg1, msg2, msg3, msg4))) + case u: UnionExec if isCometOperatorEnabled(conf, "union") && u.children.forall(isCometNative) => - QueryPlanSerde.operator2Proto(u) match { + val (newOp, info) = QueryPlanSerde.operator2Proto(u) + newOp match { case Some(nativeOp) => val cometOp = CometUnionExec(u, u.children) CometSinkPlaceHolder(nativeOp, u, cometOp) - case None => - u + case None => opWithInfo(u, info) } - // For AQE broadcast stage on a Comet broadcast exchange + case u: UnionExec => + val msg1 = if (!isCometOperatorEnabled(conf, "union")) { + CometExplainInfo("Union is not enabled") + } else { + CometExplainInfo.none + } + val msg2 = if (!u.children.forall(isCometNative)) { + CometExplainInfo("Not all subqueries for union are native") + } else { + CometExplainInfo.none + } + opWithInfo(u, CometExplainInfo("Union", Seq(msg1, msg2))) + case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => - val newOp = transform1(s) + val (newOp, info) = transform1(s) newOp match { case Some(nativeOp) => CometSinkPlaceHolder(nativeOp, s, s) case None => - s + opWithInfo(s, info) } // `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast @@ -457,11 +569,12 @@ class CometSparkSessionExtensions case b: BroadcastExchangeExec if isCometNative(b.child) && isCometOperatorEnabled(conf, "broadcastExchangeExec") => - QueryPlanSerde.operator2Proto(b) match { + val (newOp, info) = QueryPlanSerde.operator2Proto(b) + newOp match { case Some(nativeOp) => val cometOp = CometBroadcastExchangeExec(b, b.child) CometSinkPlaceHolder(nativeOp, b, cometOp) - case None => b + case None => opWithInfo(b, info) } case other => other } @@ -470,7 +583,15 @@ class CometSparkSessionExtensions if (isCometNative(newPlan) || isCometBroadCastForceEnabled(conf)) { newPlan } else { - plan + val msg = + if (!isCometOperatorEnabled( + conf, + "broadcastExchangeExec") || !isCometBroadCastForceEnabled(conf)) { + CometExplainInfo("Native Broadcast is not enabled") + } else { + CometExplainInfo.none + } + opWithInfo(plan, msg) } } else { plan @@ -478,12 +599,12 @@ class CometSparkSessionExtensions // For AQE shuffle stage on a Comet shuffle exchange case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => - val newOp = transform1(s) + val (newOp, info) = transform1(s) newOp match { case Some(nativeOp) => CometSinkPlaceHolder(nativeOp, s, s) case None => - s + opWithInfo(s, info) } // For AQE shuffle stage on a reused Comet 
shuffle exchange @@ -493,22 +614,22 @@ class CometSparkSessionExtensions _, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => - val newOp = transform1(s) + val (newOp, info) = transform1(s) newOp match { case Some(nativeOp) => CometSinkPlaceHolder(nativeOp, s, s) case None => - s + opWithInfo(s, info) } // Native shuffle for Comet operators case s: ShuffleExchangeExec if isCometShuffleEnabled(conf) && !isCometColumnarShuffleEnabled(conf) && - QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning) => + QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._1 => logInfo("Comet extension enabled for Native Shuffle") - val newOp = transform1(s) + val (newOp, info) = transform1(s) newOp match { case Some(nativeOp) => // Switch to use Decimal128 regardless of precision, since Arrow native execution @@ -517,17 +638,17 @@ class CometSparkSessionExtensions val cometOp = CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle) CometSinkPlaceHolder(nativeOp, s, cometOp) case None => - s + opWithInfo(s, info) } - // Columnar shuffle for regular Spark operators (not Comet) and Comet operators + // Arrow shuffle for regular Spark operators (not Comet) and Comet operators // (if configured) case s: ShuffleExchangeExec if isCometShuffleEnabled(conf) && isCometColumnarShuffleEnabled(conf) && - QueryPlanSerde.supportPartitioningTypes(s.child.output) => + QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._1 => logInfo("Comet extension enabled for JVM Columnar Shuffle") - val newOp = QueryPlanSerde.operator2Proto(s) + val (newOp, info) = QueryPlanSerde.operator2Proto(s) newOp match { case Some(nativeOp) => s.child match { @@ -538,12 +659,47 @@ class CometSparkSessionExtensions s } case None => - s + opWithInfo(s, info) } + case s: ShuffleExchangeExec => + val isShuffleEnabled = isCometShuffleEnabled(conf) + val msg1 = if (!isShuffleEnabled) { + CometExplainInfo("Native shuffle is not enabled") + } else { + CometExplainInfo.none + } + val columnarShuffleEnabled = isCometColumnarShuffleEnabled(conf) + val msg2 = + if (isShuffleEnabled && !columnarShuffleEnabled && !QueryPlanSerde + .supportPartitioning(s.child.output, s.outputPartitioning) + ._1) { + CometExplainInfo("Shuffle: " + + s"${QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._2}") + } else { + CometExplainInfo.none + } + val msg3 = + if (isShuffleEnabled && columnarShuffleEnabled && !QueryPlanSerde + .supportPartitioningTypes(s.child.output, s.outputPartitioning) + ._1) { + val info = + QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._2 + CometExplainInfo(s"Columnar shuffle: $info") + } else { + CometExplainInfo.none + } + opWithInfo(s, CometExplainInfo("BroadcastExchange", Seq(msg1, msg2, msg3))) + case op => // An operator that is not supported by Comet - op + op match { + case b: CometExec => b + case b: CometBroadcastExchangeExec => b + case b: CometShuffleExchangeExec => b + case o => + opWithInfo(o, CometExplainInfo(s"${o.nodeName} is not supported")) + } } } @@ -802,4 +958,12 @@ object CometSparkSessionExtensions extends Logging { ByteUnit.MiB.toBytes(shuffleMemorySize) } } + + def opWithInfo(op: SparkPlan, info: CometExplainInfo): SparkPlan = { + val simpleStr = info.toSimpleString + if (simpleStr.nonEmpty) { + op.setTagValue(CometExplainInfo.EXTENSION_INFO, simpleStr) + } + op + } } diff --git a/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala 
b/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala new file mode 100644 index 0000000000..a36fc52070 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import scala.collection.mutable + +import org.apache.spark.sql.ExtendedExplainGenerator +import org.apache.spark.sql.execution.{InputAdapter, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} + +class ExtendedExplainInfo extends ExtendedExplainGenerator { + + override def title: String = "Comet" + + override def generateExtendedInfo(plan: SparkPlan): String = { + val info = extensionInfo(plan) + info.distinct.mkString("\n") + } + + private def getActualPlan(plan: SparkPlan): SparkPlan = { + plan match { + case p: AdaptiveSparkPlanExec => getActualPlan(p.executedPlan) + case p: InputAdapter => getActualPlan(p.child) + case p: QueryStageExec => getActualPlan(p.plan) + case p: WholeStageCodegenExec => getActualPlan(p.child) + case p => p + } + } + + private def extensionInfo(plan: SparkPlan): mutable.Seq[String] = { + var info = mutable.Seq[String]() + val sorted = sortup(plan) + sorted.foreach(p => { + val s = + getActualPlan(p).getTagValue(CometExplainInfo.EXTENSION_INFO).map(t => t).getOrElse("") + if (s.nonEmpty) { + info = info :+ s + } + }) + info + } + + // get all plan nodes, breadth first, leaf nodes first + private def sortup(plan: SparkPlan): mutable.Queue[SparkPlan] = { + val ordered = new mutable.Queue[SparkPlan]() + val traversed = mutable.Queue[SparkPlan](getActualPlan(plan)) + while (traversed.nonEmpty) { + val s = traversed.dequeue() + ordered += s + if (s.innerChildren.nonEmpty) { + s.innerChildren.foreach(c => { + c match { + case _: SparkPlan => traversed.enqueue(getActualPlan(c.asInstanceOf[SparkPlan])) + case _ => + } + () + }) + } + if (s.children.nonEmpty) { + s.children.foreach(c => { + traversed.enqueue(getActualPlan(c)) + () + }) + } + } + ordered.reverse + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 26fc708ffd..bdd0ad9d80 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -20,16 +20,17 @@ package org.apache.comet.serde import scala.collection.JavaConverters._ +import scala.reflect.ClassTag import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, Final, First, Last, Max, Min, 
Partial, Sum} +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils -import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometRowToColumnarExec, CometSinkPlaceHolder, DecimalPrecision} +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometHashAggregateExec, CometPlan, CometRowToColumnarExec, CometSinkPlaceHolder, DecimalPrecision} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ @@ -41,6 +42,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.comet.{CometExplainInfo, CometExplainSubtreeIsNotNative} import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus} import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc} import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo} @@ -52,7 +54,24 @@ import org.apache.comet.shims.ShimQueryPlanSerde */ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { def emitWarning(reason: String): Unit = { - logWarning(s"Comet native execution: $reason") + logWarning(s"Comet native execution is disabled due to: $reason") + } + + def unsupported[T: ClassTag, R: ClassTag]( + op: String, + info: T): (Option[R], CometExplainInfo) = { + info match { + case s: Seq[CometExplainInfo] if s.nonEmpty => + (None, CometExplainInfo(op, s)) + case _: CometExplainSubtreeIsNotNative => + (None, CometExplainInfo.subTreeIsNotNative) + case i: CometExplainInfo => + (None, CometExplainInfo(op, i)) + case s: String => + (None, CometExplainInfo(s"$op is not supported ($s)")) + case _ => + (None, CometExplainInfo(s"$op is not supported")) + } } def supportedDataType(dt: DataType): Boolean = dt match { @@ -200,10 +219,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { def aggExprToProto( aggExpr: AggregateExpression, inputs: Seq[Attribute], - binding: Boolean): Option[AggExpr] = { + binding: Boolean): (Option[AggExpr], CometExplainInfo) = { aggExpr.aggregateFunction match { case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) => - val childExpr = exprToProto(child, inputs, binding) + val (childExpr, info) = exprToProto(child, inputs, binding) val dataType = serializeDataType(s.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -212,16 +231,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { sumBuilder.setDatatype(dataType.get) sumBuilder.setFailOnError(getFailOnError(s)) - Some( - ExprOuterClass.AggExpr - .newBuilder() - .setSum(sumBuilder) - .build()) + ( + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setSum(sumBuilder) + .build()), + CometExplainInfo.none) + } else if (dataType.isEmpty) { + unsupported("SUM", CometExplainInfo(s"datatype ${s.dataType} is not supported")) } else { - None + unsupported("SUM", info) } case s @ Average(child, _) if avgDataTypeSupported(s.dataType) => - val childExpr = exprToProto(child, 
inputs, binding) + val (childExpr, info) = exprToProto(child, inputs, binding) val dataType = serializeDataType(s.dataType) val sumDataType = if (child.dataType.isInstanceOf[DecimalType]) { @@ -244,31 +267,37 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { builder.setFailOnError(getFailOnError(s)) builder.setSumDatatype(sumDataType.get) - Some( - ExprOuterClass.AggExpr - .newBuilder() - .setAvg(builder) - .build()) + ( + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setAvg(builder) + .build()), + CometExplainInfo.none) + } else if (dataType.isEmpty) { + unsupported("AVERAGE", CometExplainInfo(s"datatype ${s.dataType} is not supported")) } else { - None + unsupported("AVERAGE", info) } case Count(children) => - val exprChildren = children.map(exprToProto(_, inputs, binding)) + val (exprChildren, exprInfo) = children.map(exprToProto(_, inputs, binding)).unzip if (exprChildren.forall(_.isDefined)) { val countBuilder = ExprOuterClass.Count.newBuilder() countBuilder.addAllChildren(exprChildren.map(_.get).asJava) - Some( - ExprOuterClass.AggExpr - .newBuilder() - .setCount(countBuilder) - .build()) + ( + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setCount(countBuilder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("COUNT", exprInfo) } case min @ Min(child) if minMaxDataTypeSupported(min.dataType) => - val childExpr = exprToProto(child, inputs, binding) + val (childExpr, info) = exprToProto(child, inputs, binding) val dataType = serializeDataType(min.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -276,16 +305,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { minBuilder.setChild(childExpr.get) minBuilder.setDatatype(dataType.get) - Some( - ExprOuterClass.AggExpr - .newBuilder() - .setMin(minBuilder) - .build()) + ( + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setMin(minBuilder) + .build()), + CometExplainInfo.none) + } else if (dataType.isEmpty) { + unsupported("MIN", CometExplainInfo(s"datatype ${min.dataType} is not supported")) } else { - None + unsupported("MIN", info) } case max @ Max(child) if minMaxDataTypeSupported(max.dataType) => - val childExpr = exprToProto(child, inputs, binding) + val (childExpr, info) = exprToProto(child, inputs, binding) val dataType = serializeDataType(max.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -293,17 +326,21 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { maxBuilder.setChild(childExpr.get) maxBuilder.setDatatype(dataType.get) - Some( - ExprOuterClass.AggExpr - .newBuilder() - .setMax(maxBuilder) - .build()) + ( + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setMax(maxBuilder) + .build()), + CometExplainInfo.none) + } else if (dataType.isEmpty) { + unsupported("MAX", CometExplainInfo(s"datatype ${max.dataType} is not supported")) } else { - None + unsupported("MAX", info) } case first @ First(child, ignoreNulls) if !ignoreNulls => // DataFusion doesn't support ignoreNulls true - val childExpr = exprToProto(child, inputs, binding) + val (childExpr, info) = exprToProto(child, inputs, binding) val dataType = serializeDataType(first.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -311,17 +348,21 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { firstBuilder.setChild(childExpr.get) firstBuilder.setDatatype(dataType.get) - Some( - ExprOuterClass.AggExpr - .newBuilder() - .setFirst(firstBuilder) - .build()) + ( + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setFirst(firstBuilder) + .build()), + 
CometExplainInfo.none) + } else if (dataType.isEmpty) { + unsupported("FIRST", CometExplainInfo(s"datatype ${first.dataType} is not supported")) } else { - None + unsupported("FIRST", info) } case last @ Last(child, ignoreNulls) if !ignoreNulls => // DataFusion doesn't support ignoreNulls true - val childExpr = exprToProto(child, inputs, binding) + val (childExpr, info) = exprToProto(child, inputs, binding) val dataType = serializeDataType(last.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -329,16 +370,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { lastBuilder.setChild(childExpr.get) lastBuilder.setDatatype(dataType.get) - Some( - ExprOuterClass.AggExpr - .newBuilder() - .setLast(lastBuilder) - .build()) + ( + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setLast(lastBuilder) + .build()), + CometExplainInfo.none) + } else if (dataType.isEmpty) { + unsupported("LAST", CometExplainInfo(s"datatype ${last.dataType} is not supported")) } else { - None + unsupported("LAST", info) } case bitAnd @ BitAndAgg(child) if bitwiseAggTypeSupported(bitAnd.dataType) => - val childExpr = exprToProto(child, inputs, binding) + val (childExpr, info) = exprToProto(child, inputs, binding) val dataType = serializeDataType(bitAnd.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -346,16 +391,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { bitAndBuilder.setChild(childExpr.get) bitAndBuilder.setDatatype(dataType.get) - Some( - ExprOuterClass.AggExpr - .newBuilder() - .setBitAndAgg(bitAndBuilder) - .build()) + ( + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setBitAndAgg(bitAndBuilder) + .build()), + CometExplainInfo.none) + } else if (dataType.isEmpty) { + unsupported("BITAND", CometExplainInfo(s"datatype ${bitAnd.dataType} is not supported")) } else { - None + unsupported("BITAND", info) } case bitOr @ BitOrAgg(child) if bitwiseAggTypeSupported(bitOr.dataType) => - val childExpr = exprToProto(child, inputs, binding) + val (childExpr, info) = exprToProto(child, inputs, binding) val dataType = serializeDataType(bitOr.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -363,16 +412,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { bitOrBuilder.setChild(childExpr.get) bitOrBuilder.setDatatype(dataType.get) - Some( - ExprOuterClass.AggExpr - .newBuilder() - .setBitOrAgg(bitOrBuilder) - .build()) + ( + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setBitOrAgg(bitOrBuilder) + .build()), + CometExplainInfo.none) + } else if (dataType.isEmpty) { + unsupported("BITOR", CometExplainInfo(s"datatype ${bitOr.dataType} is not supported")) } else { - None + unsupported("BITOR", info) } case bitXor @ BitXorAgg(child) if bitwiseAggTypeSupported(bitXor.dataType) => - val childExpr = exprToProto(child, inputs, binding) + val (childExpr, info) = exprToProto(child, inputs, binding) val dataType = serializeDataType(bitXor.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -380,18 +433,22 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { bitXorBuilder.setChild(childExpr.get) bitXorBuilder.setDatatype(dataType.get) - Some( - ExprOuterClass.AggExpr - .newBuilder() - .setBitXorAgg(bitXorBuilder) - .build()) + ( + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setBitXorAgg(bitXorBuilder) + .build()), + CometExplainInfo.none) + } else if (dataType.isEmpty) { + unsupported("BITXOR", CometExplainInfo(s"datatype ${bitXor.dataType} is not supported")) } else { - None + unsupported("BITXOR", info) } case 
fn => emitWarning(s"unsupported Spark aggregate function: $fn") - None + unsupported(fn.prettyName, CometExplainInfo.none) } } @@ -410,11 +467,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { def exprToProto( expr: Expression, input: Seq[Attribute], - binding: Boolean = true): Option[Expr] = { + binding: Boolean = true): (Option[Expr], CometExplainInfo) = { def castToProto( timeZoneId: Option[String], dt: DataType, - childExpr: Option[Expr]): Option[Expr] = { + childExpr: Option[Expr]): (Option[Expr], CometExplainInfo) = { val dataType = serializeDataType(dt) if (childExpr.isDefined && dataType.isDefined) { @@ -425,17 +482,25 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val timeZone = timeZoneId.getOrElse("UTC") castBuilder.setTimezone(timeZone) - Some( - ExprOuterClass.Expr - .newBuilder() - .setCast(castBuilder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setCast(castBuilder) + .build()), + CometExplainInfo.none) } else { - None + if (!dataType.isDefined) { + unsupported("CAST", CometExplainInfo(s"Unsupported datatype ${dt}")) + } else { + unsupported("CAST", CometExplainInfo(s"Unsupported expression ${childExpr}")) + } } } - def exprToProtoInternal(expr: Expression, inputs: Seq[Attribute]): Option[Expr] = { + def exprToProtoInternal( + expr: Expression, + inputs: Seq[Attribute]): (Option[Expr], CometExplainInfo) = { SQLConf.get expr match { case a @ Alias(_, _) => @@ -447,12 +512,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { exprToProtoInternal(Literal(value, dataType), inputs) case Cast(child, dt, timeZoneId, _) => - val childExpr = exprToProtoInternal(child, inputs) - castToProto(timeZoneId, dt, childExpr) + val (childExpr, info) = exprToProtoInternal(child, inputs) + if (childExpr.isDefined) { + castToProto(timeZoneId, dt, childExpr) + } else { + unsupported("CAST", info) + } case add @ Add(left, right, _) if supportedDataType(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val addBuilder = ExprOuterClass.Add.newBuilder() @@ -463,18 +532,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { addBuilder.setReturnType(t) } - Some( - ExprOuterClass.Expr - .newBuilder() - .setAdd(addBuilder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setAdd(addBuilder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("ADD", Seq(leftInfo, rightInfo)) } case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Subtract.newBuilder() @@ -485,19 +556,21 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { builder.setReturnType(t) } - Some( - ExprOuterClass.Expr - .newBuilder() - .setSubtract(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setSubtract(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("SUBTRACT", Seq(leftInfo, rightInfo)) } case mul @ Multiply(left, right, _) if supportedDataType(left.dataType) && 
!decimalBeforeSpark34(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Multiply.newBuilder() @@ -508,22 +581,25 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { builder.setReturnType(t) } - Some( - ExprOuterClass.Expr - .newBuilder() - .setMultiply(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setMultiply(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("MULTIPLY", Seq(leftInfo, rightInfo)) } case div @ Divide(left, right, _) if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) // Datafusion now throws an exception for dividing by zero // See https://github.com/apache/arrow-datafusion/pull/6792 // For now, use NullIf to swap zeros with nulls. - val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs) + val (rightExpr, rightInfo) = + exprToProtoInternal(nullIfWhenPrimitive(right), inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Divide.newBuilder() @@ -534,19 +610,22 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { builder.setReturnType(t) } - Some( - ExprOuterClass.Expr - .newBuilder() - .setDivide(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setDivide(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("DIVIDE", Seq(leftInfo, rightInfo)) } case rem @ Remainder(left, right, _) if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = + exprToProtoInternal(nullIfWhenPrimitive(right), inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Remainder.newBuilder() @@ -557,157 +636,175 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { builder.setReturnType(t) } - Some( - ExprOuterClass.Expr - .newBuilder() - .setRemainder(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setRemainder(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("REMAINDER", Seq(leftInfo, rightInfo)) } case EqualTo(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Equal.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setEq(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setEq(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("EQUALTO", Seq(leftInfo, rightInfo)) } case Not(EqualTo(left, right)) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) 
= exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.NotEqual.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setNeq(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setNeq(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("NOTEQUALTO", Seq(leftInfo, rightInfo)) } case EqualNullSafe(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.EqualNullSafe.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setEqNullSafe(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setEqNullSafe(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("EQUALNULLSAFE", Seq(leftInfo, rightInfo)) } case Not(EqualNullSafe(left, right)) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.NotEqualNullSafe.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setNeqNullSafe(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setNeqNullSafe(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("NOTEQUALNULLSAFE", Seq(leftInfo, rightInfo)) } case GreaterThan(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.GreaterThan.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setGt(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setGt(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("GREATERTHAN", Seq(leftInfo, rightInfo)) } case GreaterThanOrEqual(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.GreaterThanEqual.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setGtEq(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setGtEq(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("GREATERTHANOREQUAL", Seq(leftInfo, rightInfo)) } case LessThan(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if 
(leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.LessThan.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setLt(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setLt(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("LESSTHAN", Seq(leftInfo, rightInfo)) } case LessThanOrEqual(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.LessThanEqual.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setLtEq(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setLtEq(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("LESSTHANOREQUAL", Seq(leftInfo, rightInfo)) } case Literal(value, dataType) if supportedDataType(dataType) => @@ -748,17 +845,19 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { if (dt.isDefined) { exprBuilder.setDatatype(dt.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setLiteral(exprBuilder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setLiteral(exprBuilder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("LITERAL", CometExplainInfo(s"Unsupported datatype $dataType")) } case Substring(str, Literal(pos, _), Literal(len, _)) => - val strExpr = exprToProtoInternal(str, inputs) + val (strExpr, info) = exprToProtoInternal(str, inputs) if (strExpr.isDefined) { val builder = ExprOuterClass.Substring.newBuilder() @@ -766,125 +865,137 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { builder.setStart(pos.asInstanceOf[Int]) builder.setLen(len.asInstanceOf[Int]) - Some( - ExprOuterClass.Expr - .newBuilder() - .setSubstring(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setSubstring(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("SUBSTRING", info) } case Like(left, right, _) => // TODO escapeChar - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Like.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setLike(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setLike(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("LIKE", Seq(leftInfo, rightInfo)) } // TODO waiting for arrow-rs update - // case RLike(left, right) => - // val leftExpr = exprToProtoInternal(left, inputs) - // val rightExpr = exprToProtoInternal(right, inputs) - // - // if (leftExpr.isDefined && rightExpr.isDefined) { - // val builder = ExprOuterClass.RLike.newBuilder() - // builder.setLeft(leftExpr.get) - // builder.setRight(rightExpr.get) - // - // Some( - // ExprOuterClass.Expr - // .newBuilder() - // .setRlike(builder) - // .build()) - // } else { - // None - // } +// case RLike(left, right) => +// val leftExpr = exprToProtoInternal(left, inputs) +// val rightExpr = 
exprToProtoInternal(right, inputs) +// +// if (leftExpr.isDefined && rightExpr.isDefined) { +// val builder = ExprOuterClass.RLike.newBuilder() +// builder.setLeft(leftExpr.get) +// builder.setRight(rightExpr.get) +// +// Some( +// ExprOuterClass.Expr +// .newBuilder() +// .setRlike(builder) +// .build()) +// } else { +// None +// } case StartsWith(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.StartsWith.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setStartsWith(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setStartsWith(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("STARTSWITH", Seq(leftInfo, rightInfo)) } case EndsWith(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.EndsWith.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setEndsWith(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setEndsWith(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("ENDSWITH", Seq(leftInfo, rightInfo)) } case Contains(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Contains.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setContains(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setContains(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("CONTAINS", Seq(leftInfo, rightInfo)) } case StringSpace(child) => - val childExpr = exprToProtoInternal(child, inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { val builder = ExprOuterClass.StringSpace.newBuilder() builder.setChild(childExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setStringSpace(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setStringSpace(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("STRINGSPACE", info) } case Hour(child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { val builder = ExprOuterClass.Hour.newBuilder() @@ -893,17 +1004,19 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val timeZone = timeZoneId.getOrElse("UTC") builder.setTimezone(timeZone) - Some( - ExprOuterClass.Expr - .newBuilder() - .setHour(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setHour(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("HOUR", info) } case Minute(child, 
timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { val builder = ExprOuterClass.Minute.newBuilder() @@ -912,36 +1025,40 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val timeZone = timeZoneId.getOrElse("UTC") builder.setTimezone(timeZone) - Some( - ExprOuterClass.Expr - .newBuilder() - .setMinute(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setMinute(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("MINUTE", info) } case TruncDate(child, format) => - val childExpr = exprToProtoInternal(child, inputs) - val formatExpr = exprToProtoInternal(format, inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) + val (formatExpr, formatInfo) = exprToProtoInternal(format, inputs) if (childExpr.isDefined && formatExpr.isDefined) { val builder = ExprOuterClass.TruncDate.newBuilder() builder.setChild(childExpr.get) builder.setFormat(formatExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setTruncDate(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setTruncDate(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("TRUNCDATE", Seq(info, formatInfo)) } case TruncTimestamp(format, child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs) - val formatExpr = exprToProtoInternal(format, inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) + val (formatExpr, formatInfo) = exprToProtoInternal(format, inputs) if (childExpr.isDefined && formatExpr.isDefined) { val builder = ExprOuterClass.TruncTimestamp.newBuilder() @@ -951,17 +1068,19 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val timeZone = timeZoneId.getOrElse("UTC") builder.setTimezone(timeZone) - Some( - ExprOuterClass.Expr - .newBuilder() - .setTruncTimestamp(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setTruncTimestamp(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("TRUNCTIMESTAMP", Seq(info, formatInfo)) } case Second(child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { val builder = ExprOuterClass.Second.newBuilder() @@ -970,18 +1089,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val timeZone = timeZoneId.getOrElse("UTC") builder.setTimezone(timeZone) - Some( - ExprOuterClass.Expr - .newBuilder() - .setSecond(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setSecond(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("SECOND", info) } case Year(child) => - val periodType = exprToProtoInternal(Literal("year"), inputs) - val childExpr = exprToProtoInternal(child, inputs) + val (periodType, _) = exprToProtoInternal(Literal("year"), inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) scalarExprToProto("datepart", Seq(periodType, childExpr): _*) .map(e => { Expr @@ -993,42 +1114,49 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setDatatype(serializeDataType(IntegerType).get) .build()) .build() - }) + }) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("YEAR", info) + } case IsNull(child) => - val childExpr = exprToProtoInternal(child, inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { val 
castBuilder = ExprOuterClass.IsNull.newBuilder() castBuilder.setChild(childExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setIsNull(castBuilder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setIsNull(castBuilder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("ISNULL", info) } case IsNotNull(child) => - val childExpr = exprToProtoInternal(child, inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { val castBuilder = ExprOuterClass.IsNotNull.newBuilder() castBuilder.setChild(childExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setIsNotNull(castBuilder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setIsNotNull(castBuilder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("ISNOTNULL", info) } case SortOrder(child, direction, nullOrdering, _) => - val childExpr = exprToProtoInternal(child, inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder() @@ -1044,49 +1172,55 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case NullsLast => sortOrderBuilder.setNullOrderingValue(1) } - Some( - ExprOuterClass.Expr - .newBuilder() - .setSortOrder(sortOrderBuilder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setSortOrder(sortOrderBuilder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("SORTORDER", info) } case And(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.And.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setAnd(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setAnd(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("AND", Seq(leftInfo, rightInfo)) } case Or(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Or.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setOr(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setOr(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("OR", Seq(leftInfo, rightInfo)) } case UnaryExpression(child) if expr.prettyName == "promote_precision" => @@ -1095,7 +1229,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { exprToProtoInternal(child, inputs) case CheckOverflow(child, dt, nullOnOverflow) => - val childExpr = exprToProtoInternal(child, inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { val builder = ExprOuterClass.CheckOverflow.newBuilder() @@ -1106,13 +1240,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val dataType = serializeDataType(dt) builder.setDatatype(dataType.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setCheckOverflow(builder) - .build()) + ( + Some( + 
ExprOuterClass.Expr + .newBuilder() + .setCheckOverflow(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("CHECKOVERFLOW", info) } case attr: AttributeReference => @@ -1129,101 +1265,143 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .setDatatype(dataType.get) .build() - Some( - ExprOuterClass.Expr - .newBuilder() - .setBound(boundExpr) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setBound(boundExpr) + .build()), + CometExplainInfo.none) } else { val unboundRef = ExprOuterClass.UnboundReference .newBuilder() .setName(attr.name) .setDatatype(dataType.get) .build() - Some( - ExprOuterClass.Expr - .newBuilder() - .setUnbound(unboundRef) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setUnbound(unboundRef) + .build()), + CometExplainInfo.none) } } else { - None + unsupported("ATTRREF", CometExplainInfo(s"unsupported datatype: ${attr.dataType}")) } case Abs(child, _) => - exprToProtoInternal(child, inputs).map(childExpr => { + val (childExpr, info) = exprToProtoInternal(child, inputs) + if (childExpr.isDefined) { val abs = ExprOuterClass.Abs .newBuilder() - .setChild(childExpr) + .setChild(childExpr.get) .build() - Expr.newBuilder().setAbs(abs).build() - }) + (Some(Expr.newBuilder().setAbs(abs).build()), CometExplainInfo.none) + } else { + unsupported("ABS", info) + } case Acos(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("acos", childExpr) + val (childExpr, info) = exprToProtoInternal(child, inputs) + scalarExprToProto("acos", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("ACOS", info) + } case Asin(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("asin", childExpr) + val (childExpr, info) = exprToProtoInternal(child, inputs) + scalarExprToProto("asin", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("ASIN", info) + } case Atan(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("atan", childExpr) + val (childExpr, info) = exprToProtoInternal(child, inputs) + scalarExprToProto("atan", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("ATAN", info) + } case Atan2(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - scalarExprToProto("atan2", leftExpr, rightExpr) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) + scalarExprToProto("atan2", leftExpr, rightExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("ATAN2", Seq(leftInfo, rightInfo)) + } case e @ Ceil(child) => - val childExpr = exprToProtoInternal(child, inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) child.dataType match { case t: DecimalType if t.scale == 0 => // zero scale is no-op - childExpr + (childExpr, CometExplainInfo.none) case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 - None + unsupported("CEIL", Seq(info, CometExplainInfo("Decimal type has negative scale"))) case _ => - scalarExprToProtoWithReturnType("ceil", e.dataType, childExpr) + scalarExprToProtoWithReturnType("ceil", e.dataType, childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("CEIL", info) + } } case Cos(child) => - val childExpr = 
exprToProtoInternal(child, inputs) - scalarExprToProto("cos", childExpr) + val (childExpr, info) = exprToProtoInternal(child, inputs) + scalarExprToProto("cos", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("COS", info) + } case Exp(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("exp", childExpr) + val (childExpr, info) = exprToProtoInternal(child, inputs) + scalarExprToProto("exp", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("EXP", info) + } case e @ Floor(child) => - val childExpr = exprToProtoInternal(child, inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) child.dataType match { case t: DecimalType if t.scale == 0 => // zero scale is no-op - childExpr + (childExpr, CometExplainInfo.none) case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 - None + unsupported("FLOOR", Seq(info, CometExplainInfo("Decimal type has negative scale"))) case _ => - scalarExprToProtoWithReturnType("floor", e.dataType, childExpr) + scalarExprToProtoWithReturnType("floor", e.dataType, childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("FLOOR", info) + } } case Log(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("ln", childExpr) + val (childExpr, info) = exprToProtoInternal(child, inputs) + scalarExprToProto("ln", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("LN", info) + } case Log10(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("log10", childExpr) + val (childExpr, info) = exprToProtoInternal(child, inputs) + scalarExprToProto("log10", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("LOG10", info) + } case Log2(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("log2", childExpr) + val (childExpr, info) = exprToProtoInternal(child, inputs) + scalarExprToProto("log2", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("LOG2", info) + } case Pow(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - scalarExprToProto("pow", leftExpr, rightExpr) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) + scalarExprToProto("pow", leftExpr, rightExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("ACOS", Seq(leftInfo, rightInfo)) + } // round function for Spark 3.2 does not allow negative round target scale. In addition, // it has different result precision/scale for decimals. Supporting only 3.3 and above. 
@@ -1232,14 +1410,22 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val scaleV: Any = r.scale.eval(EmptyRow) val _scale: Int = scaleV.asInstanceOf[Int] - lazy val childExpr = exprToProtoInternal(r.child, inputs) + lazy val (childExpr, info) = exprToProtoInternal(r.child, inputs) r.child.dataType match { case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 - None + unsupported("ROUND", Seq(info, CometExplainInfo("Decimal type has negative scale"))) case _ if scaleV == null => - exprToProtoInternal(Literal(null), inputs) + val (childScaleIsNull, infoScaleIsNull) = + exprToProtoInternal(Literal(null), inputs) + childScaleIsNull match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("ROUND", Seq(info, infoScaleIsNull)) + } case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 => - childExpr // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark + ( + childExpr, + CometExplainInfo.none + ) // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark case _: FloatType | DoubleType => // We cannot properly match with the Spark behavior for floating-point numbers. // Spark uses BigDecimal for rounding float/double, and BigDecimal fist converts a @@ -1255,133 +1441,204 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { // I.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not. // That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead // of 6.1317116247283999E18. - None + unsupported( + "ROUND", + CometExplainInfo("Comet does not support Spark's BigDecimal rounding")) case _ => // `scale` must be Int64 type in DataFusion - val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs) - scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr) + val (scaleExpr, info) = + exprToProtoInternal(Literal(_scale.toLong, LongType), inputs) + scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("ROUND", info) + } } case Signum(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("signum", childExpr) + val (childExpr, info) = exprToProtoInternal(child, inputs) + scalarExprToProto("signum", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("SIGNUM", info) + } case Sin(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("sin", childExpr) + val (childExpr, info) = exprToProtoInternal(child, inputs) + scalarExprToProto("sin", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("SIN", info) + } case Sqrt(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("sqrt", childExpr) + val (childExpr, info) = exprToProtoInternal(child, inputs) + scalarExprToProto("sqrt", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("SQRT", info) + } case Tan(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("tan", childExpr) + val (childExpr, info) = exprToProtoInternal(child, inputs) + scalarExprToProto("tan", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("TAN", info) + } case Ascii(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("ascii", 
childExpr) + val (childExpr, info) = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("ascii", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("ASCII", info) + } case BitLength(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("bit_length", childExpr) + val (childExpr, info) = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("bit_length", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("BIT_LENGTH", info) + } case If(predicate, trueValue, falseValue) => - val predicateExpr = exprToProtoInternal(predicate, inputs) - val trueExpr = exprToProtoInternal(trueValue, inputs) - val falseExpr = exprToProtoInternal(falseValue, inputs) + val (predicateExpr, predicateInfo) = exprToProtoInternal(predicate, inputs) + val (trueExpr, trueInfo) = exprToProtoInternal(trueValue, inputs) + val (falseExpr, falseInfo) = exprToProtoInternal(falseValue, inputs) if (predicateExpr.isDefined && trueExpr.isDefined && falseExpr.isDefined) { val builder = ExprOuterClass.IfExpr.newBuilder() builder.setIfExpr(predicateExpr.get) builder.setTrueExpr(trueExpr.get) builder.setFalseExpr(falseExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setIf(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setIf(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("IF", Seq(predicateInfo, trueInfo, falseInfo)) } case CaseWhen(branches, elseValue) => - val whenSeq = branches.map(elements => exprToProtoInternal(elements._1, inputs)) - val thenSeq = branches.map(elements => exprToProtoInternal(elements._2, inputs)) + val (whenSeq, whenInfo) = + branches.map(elements => exprToProtoInternal(elements._1, inputs)).unzip + val (thenSeq, thenInfo) = + branches.map(elements => exprToProtoInternal(elements._2, inputs)).unzip assert(whenSeq.length == thenSeq.length) if (whenSeq.forall(_.isDefined) && thenSeq.forall(_.isDefined)) { val builder = ExprOuterClass.CaseWhen.newBuilder() builder.addAllWhen(whenSeq.map(_.get).asJava) builder.addAllThen(thenSeq.map(_.get).asJava) if (elseValue.isDefined) { - val elseValueExpr = exprToProtoInternal(elseValue.get, inputs) + val (elseValueExpr, elseValueInfo) = + exprToProtoInternal(elseValue.get, inputs) if (elseValueExpr.isDefined) { builder.setElseExpr(elseValueExpr.get) } else { - return None + return unsupported("CASE", elseValueInfo) } } - Some( - ExprOuterClass.Expr - .newBuilder() - .setCaseWhen(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setCaseWhen(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("CASE", whenInfo ++ thenInfo) } - case ConcatWs(children) => val exprs = children.map(e => exprToProtoInternal(Cast(e, StringType), inputs)) - scalarExprToProto("concat_ws", exprs: _*) + scalarExprToProto("concat_ws", exprs.map(_._1): _*) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("CONCAT_WS", exprs.map(_._2)) + } case Chr(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("chr", childExpr) + val (childExpr, info) = exprToProtoInternal(child, inputs) + scalarExprToProto("chr", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("CHR", info) + } case InitCap(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("initcap", childExpr) + val 
(childExpr, info) = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("initcap", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("INITCAP", info) + } case Length(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("length", childExpr) + val (childExpr, info) = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("length", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("LENGTH", info) + } case Lower(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("lower", childExpr) + val (childExpr, info) = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("lower", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("LOWER", info) + } case Md5(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("md5", childExpr) + val (childExpr, info) = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("md5", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("MD5", info) + } case OctetLength(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("octet_length", childExpr) + val (childExpr, info) = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("octet_length", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("OCTET_LENGTH", info) + } case Reverse(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("reverse", childExpr) + val (childExpr, info) = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("reverse", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("REVERSE", info) + } case StringInstr(str, substr) => - val leftExpr = exprToProtoInternal(Cast(str, StringType), inputs) - val rightExpr = exprToProtoInternal(Cast(substr, StringType), inputs) - scalarExprToProto("strpos", leftExpr, rightExpr) + val (leftExpr, leftInfo) = exprToProtoInternal(Cast(str, StringType), inputs) + val (rightExpr, rightInfo) = + exprToProtoInternal(Cast(substr, StringType), inputs) + scalarExprToProto("strpos", leftExpr, rightExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("STRPOS", Seq(leftInfo, rightInfo)) + } case StringRepeat(str, times) => - val leftExpr = exprToProtoInternal(Cast(str, StringType), inputs) - val rightExpr = exprToProtoInternal(Cast(times, LongType), inputs) - scalarExprToProto("repeat", leftExpr, rightExpr) + val (leftExpr, leftInfo) = exprToProtoInternal(Cast(str, StringType), inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(Cast(times, LongType), inputs) + scalarExprToProto("repeat", leftExpr, rightExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("REPEAT", Seq(leftInfo, rightInfo)) + } case StringReplace(src, search, replace) => - val srcExpr = exprToProtoInternal(Cast(src, StringType), inputs) - val searchExpr = exprToProtoInternal(Cast(search, StringType), inputs) - val replaceExpr = exprToProtoInternal(Cast(replace, StringType), inputs) - scalarExprToProto("replace", srcExpr, searchExpr, replaceExpr) + val (srcExpr, srcInfo) = exprToProtoInternal(Cast(src, StringType), 
inputs) + val (searchExpr, searchInfo) = + exprToProtoInternal(Cast(search, StringType), inputs) + val (replaceExpr, replaceInfo) = + exprToProtoInternal(Cast(replace, StringType), inputs) + scalarExprToProto("replace", srcExpr, searchExpr, replaceExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("REPLACE", Seq(srcInfo, searchInfo, replaceInfo)) + } case StringTranslate(src, matching, replace) => - val srcExpr = exprToProtoInternal(Cast(src, StringType), inputs) - val matchingExpr = exprToProtoInternal(Cast(matching, StringType), inputs) - val replaceExpr = exprToProtoInternal(Cast(replace, StringType), inputs) - scalarExprToProto("translate", srcExpr, matchingExpr, replaceExpr) + val (srcExpr, srcInfo) = exprToProtoInternal(Cast(src, StringType), inputs) + val (matchingExpr, matchingInfo) = + exprToProtoInternal(Cast(matching, StringType), inputs) + val (replaceExpr, replaceInfo) = + exprToProtoInternal(Cast(replace, StringType), inputs) + scalarExprToProto("translate", srcExpr, matchingExpr, replaceExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("TRANSLATE", Seq(srcInfo, matchingInfo, replaceInfo)) + } case StringTrim(srcStr, trimStr) => trim(srcStr, trimStr, inputs, "trim") @@ -1396,82 +1653,93 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { trim(srcStr, trimStr, inputs, "btrim") case Upper(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("upper", childExpr) + val (childExpr, info) = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("upper", childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("UPPER", info) + } case BitwiseAnd(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.BitwiseAnd.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseAnd(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseAnd(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("BITWISE_AND", Seq(leftInfo, rightInfo)) } case BitwiseNot(child) => - val childExpr = exprToProtoInternal(child, inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { val builder = ExprOuterClass.BitwiseNot.newBuilder() builder.setChild(childExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseNot(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseNot(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("BITWISE_NOT", info) } case BitwiseOr(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.BitwiseOr.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseOr(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseOr(builder) + 
.build()), + CometExplainInfo.none) } else { - None + unsupported("BITWISE_OR", Seq(leftInfo, rightInfo)) } case BitwiseXor(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = exprToProtoInternal(right, inputs) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.BitwiseXor.newBuilder() builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseXor(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseXor(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("BITWISE_XOR", Seq(leftInfo, rightInfo)) } case ShiftRight(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = if (left.dataType == LongType) { + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = if (left.dataType == LongType) { // DataFusion bitwise shift right expression requires // same data type between left and right side exprToProtoInternal(Cast(right, LongType), inputs) @@ -1484,18 +1752,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseShiftRight(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseShiftRight(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("SHIFT_RIGHT", Seq(leftInfo, rightInfo)) } case ShiftLeft(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = if (left.dataType == LongType) { + val (leftExpr, leftInfo) = exprToProtoInternal(left, inputs) + val (rightExpr, rightInfo) = if (left.dataType == LongType) { // DataFusion bitwise shift left expression requires // same data type between left and right side exprToProtoInternal(Cast(right, LongType), inputs) @@ -1508,17 +1778,19 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { builder.setLeft(leftExpr.get) builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseShiftLeft(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseShiftLeft(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("SHIFT_LEFT", Seq(leftInfo, rightInfo)) } case In(value, list) => - in(value, list, inputs, false) + in(value, list, inputs, false, "IN") case InSet(value, hset) => val valueDataType = value.dataType @@ -1527,128 +1799,154 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { }.toSeq // Change `InSet` to `In` expression // We do Spark `InSet` optimization in native (DataFusion) side. 
- in(value, list, inputs, false) + in(value, list, inputs, false, "INSET") case Not(In(value, list)) => - in(value, list, inputs, true) + in(value, list, inputs, true, "NOT_IN") case Not(child) => - val childExpr = exprToProtoInternal(child, inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { val builder = ExprOuterClass.Not.newBuilder() builder.setChild(childExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setNot(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setNot(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("NOT", info) } case UnaryMinus(child, _) => - val childExpr = exprToProtoInternal(child, inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { val builder = ExprOuterClass.Negative.newBuilder() builder.setChild(childExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setNegative(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setNegative(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("UNARY_MINUS", info) } case a @ Coalesce(_) => - val exprChildren = a.children.map(exprToProtoInternal(_, inputs)) + val (exprChildren, info) = a.children.map(exprToProtoInternal(_, inputs)).unzip val childExpr = scalarExprToProto("coalesce", exprChildren: _*) // TODO: Remove this once we have new DataFusion release which includes // the fix: https://github.com/apache/arrow-datafusion/pull/9459 - castToProto(None, a.dataType, childExpr) + if (childExpr.isDefined) { + castToProto(None, a.dataType, childExpr) + } else { + unsupported("COALESCE", info) + } // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for // char types. Use rpad to achieve the behavior. 
// See https://github.com/apache/spark/pull/38151 case StaticInvoke( - clz: Class[_], + _: Class[CharVarcharCodegenUtils], _: StringType, "readSidePadding", arguments, _, true, false, - true) if clz == classOf[CharVarcharCodegenUtils] && arguments.size == 2 => - val argsExpr = Seq( + true) if arguments.size == 2 => + val (argsExpr, argsInfo) = Seq( exprToProtoInternal(Cast(arguments(0), StringType), inputs), - exprToProtoInternal(arguments(1), inputs)) + exprToProtoInternal(arguments(1), inputs)).unzip if (argsExpr.forall(_.isDefined)) { val builder = ExprOuterClass.ScalarFunc.newBuilder() builder.setFunc("rpad") argsExpr.foreach(arg => builder.addArgs(arg.get)) - Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build()) + ( + Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build()), + CometExplainInfo.none) } else { - None + unsupported("STATICINVOKE_RPAD", argsInfo) } case KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) => + val name = "FP_NORMALIZED" val dataType = serializeDataType(expr.dataType) if (dataType.isEmpty) { - return None - } - exprToProtoInternal(expr, inputs).map { child => - val builder = ExprOuterClass.NormalizeNaNAndZero - .newBuilder() - .setChild(child) - .setDatatype(dataType.get) - ExprOuterClass.Expr.newBuilder().setNormalizeNanAndZero(builder).build() + return unsupported(name, CometExplainInfo(s"Unsupported datatype ${expr.dataType}")) } + val (ex, _) = exprToProtoInternal(expr, inputs) + ( + ex.map { child => + val builder = ExprOuterClass.NormalizeNaNAndZero + .newBuilder() + .setChild(child) + .setDatatype(dataType.get) + ExprOuterClass.Expr.newBuilder().setNormalizeNanAndZero(builder).build() + }, + CometExplainInfo.none) case s @ execution.ScalarSubquery(_, _) => val dataType = serializeDataType(s.dataType) if (dataType.isEmpty) { - return None + return unsupported( + "SCALAR_SUBQUERY", + CometExplainInfo(s"Unsupported datatype ${s.dataType}")) } val builder = ExprOuterClass.Subquery .newBuilder() .setId(s.exprId.id) .setDatatype(dataType.get) - Some(ExprOuterClass.Expr.newBuilder().setSubquery(builder).build()) + ( + Some(ExprOuterClass.Expr.newBuilder().setSubquery(builder).build()), + CometExplainInfo.none) case UnscaledValue(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProtoWithReturnType("unscaled_value", LongType, childExpr) + val (childExpr, info) = exprToProtoInternal(child, inputs) + scalarExprToProtoWithReturnType("unscaled_value", LongType, childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("UNSCALED_VALUE", info) + } case MakeDecimal(child, precision, scale, true) => - val childExpr = exprToProtoInternal(child, inputs) + val (childExpr, info) = exprToProtoInternal(child, inputs) scalarExprToProtoWithReturnType( "make_decimal", DecimalType(precision, scale), - childExpr) + childExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported("MAKE_DECIMAL", info) + } + case b @ BinaryExpression(_, _) if isBloomFilterMightContain(b) => val bloomFilter = b.left val value = b.right - val bloomFilterExpr = exprToProtoInternal(bloomFilter, inputs) - val valueExpr = exprToProtoInternal(value, inputs) + val (bloomFilterExpr, bloomInfo) = exprToProtoInternal(bloomFilter, inputs) + val (valueExpr, valueInfo) = exprToProtoInternal(value, inputs) if (bloomFilterExpr.isDefined && valueExpr.isDefined) { val builder = ExprOuterClass.BloomFilterMightContain.newBuilder() builder.setBloomFilter(bloomFilterExpr.get) 
builder.setValue(valueExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setBloomFilterMightContain(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setBloomFilterMightContain(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported("BLOOMFILTER", Seq(bloomInfo, valueInfo)) } - case e => - emitWarning(s"unsupported Spark expression: '$e' of class '${e.getClass.getName}") - None + case _ => + unsupported(expr.prettyName, CometExplainInfo(s"${expr.prettyName} is not supported")) } } @@ -1656,13 +1954,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { srcStr: Expression, trimStr: Option[Expression], inputs: Seq[Attribute], - trimType: String): Option[Expr] = { - val srcExpr = exprToProtoInternal(Cast(srcStr, StringType), inputs) + trimType: String): (Option[Expr], CometExplainInfo) = { + val (srcExpr, srcInfo) = exprToProtoInternal(Cast(srcStr, StringType), inputs) if (trimStr.isDefined) { - val trimExpr = exprToProtoInternal(Cast(trimStr.get, StringType), inputs) - scalarExprToProto(trimType, srcExpr, trimExpr) + val (trimExpr, trimInfo) = + exprToProtoInternal(Cast(trimStr.get, StringType), inputs) + scalarExprToProto(trimType, srcExpr, trimExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported(trimType.toUpperCase(), Seq(srcInfo, trimInfo)) + } } else { - scalarExprToProto(trimType, srcExpr) + scalarExprToProto(trimType, srcExpr) match { + case Some(e) => (Some(e), CometExplainInfo.none) + case None => unsupported(trimType.toUpperCase(), srcInfo) + } } } @@ -1670,21 +1975,24 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { value: Expression, list: Seq[Expression], inputs: Seq[Attribute], - negate: Boolean): Option[Expr] = { - val valueExpr = exprToProtoInternal(value, inputs) - val listExprs = list.map(exprToProtoInternal(_, inputs)) + negate: Boolean, + displayName: String): (Option[Expr], CometExplainInfo) = { + val (valueExpr, valueInfo) = exprToProtoInternal(value, inputs) + val (listExprs, listInfos) = list.map(exprToProtoInternal(_, inputs)).unzip if (valueExpr.isDefined && listExprs.forall(_.isDefined)) { val builder = ExprOuterClass.In.newBuilder() builder.setInValue(valueExpr.get) builder.addAllLists(listExprs.map(_.get).asJava) builder.setNegated(negate) - Some( - ExprOuterClass.Expr - .newBuilder() - .setIn(builder) - .build()) + ( + Some( + ExprOuterClass.Expr + .newBuilder() + .setIn(builder) + .build()), + CometExplainInfo.none) } else { - None + unsupported(displayName, listExprs ++ Seq(valueInfo)) } } @@ -1747,43 +2055,43 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { * The converted Comet native operator for the input `op`, or `None` if the `op` cannot be * converted to a native operator. 
*/ - def operator2Proto(op: SparkPlan, childOp: Operator*): Option[Operator] = { + def operator2Proto(op: SparkPlan, childOp: Operator*): (Option[Operator], CometExplainInfo) = { val result = OperatorOuterClass.Operator.newBuilder() childOp.foreach(result.addChildren) op match { case ProjectExec(projectList, child) if isCometOperatorEnabled(op.conf, "project") => - val exprs = projectList.map(exprToProto(_, child.output)) + val (exprs, exprsInfo) = projectList.map(exprToProto(_, child.output)).unzip if (exprs.forall(_.isDefined) && childOp.nonEmpty) { val projectBuilder = OperatorOuterClass.Projection .newBuilder() .addAllProjectList(exprs.map(_.get).asJava) - Some(result.setProjection(projectBuilder).build()) + (Some(result.setProjection(projectBuilder).build()), null) } else { - None + unsupported("CometProject", exprsInfo) } case FilterExec(condition, child) if isCometOperatorEnabled(op.conf, "filter") => - val cond = exprToProto(condition, child.output) + val (cond, info) = exprToProto(condition, child.output) if (cond.isDefined && childOp.nonEmpty) { val filterBuilder = OperatorOuterClass.Filter.newBuilder().setPredicate(cond.get) - Some(result.setFilter(filterBuilder).build()) + (Some(result.setFilter(filterBuilder).build()), CometExplainInfo.none) } else { - None + unsupported("CometFilter", info) } case SortExec(sortOrder, _, child, _) if isCometOperatorEnabled(op.conf, "sort") => - val sortOrders = sortOrder.map(exprToProto(_, child.output)) + val (sortOrders, sortOrdersInfo) = sortOrder.map(exprToProto(_, child.output)).unzip if (sortOrders.forall(_.isDefined) && childOp.nonEmpty) { val sortBuilder = OperatorOuterClass.Sort .newBuilder() .addAllSortOrders(sortOrders.map(_.get).asJava) - Some(result.setSort(sortBuilder).build()) + (Some(result.setSort(sortBuilder).build()), CometExplainInfo.none) } else { - None + unsupported("CometSort", sortOrdersInfo) } case LocalLimitExec(limit, _) if isCometOperatorEnabled(op.conf, "local_limit") => @@ -1794,9 +2102,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .newBuilder() .setLimit(limit) .setOffset(0) - Some(result.setLimit(limitBuilder).build()) + (Some(result.setLimit(limitBuilder).build()), CometExplainInfo.none) } else { - None + unsupported("CometLocalLimit", CometExplainInfo("No child operator")) } case globalLimitExec: GlobalLimitExec if isCometOperatorEnabled(op.conf, "global_limit") => @@ -1811,22 +2119,23 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { limitBuilder.setLimit(globalLimitExec.limit) limitBuilder.setOffset(0) - Some(result.setLimit(limitBuilder).build()) + (Some(result.setLimit(limitBuilder).build()), CometExplainInfo.none) } else { - None + unsupported("CometGlobalLimit", CometExplainInfo("No child operator")) } case ExpandExec(projections, _, child) if isCometOperatorEnabled(op.conf, "expand") => - val projExprs = projections.flatMap(_.map(exprToProto(_, child.output))) + val (projExprs, projInfos) = + projections.flatMap(_.map(exprToProto(_, child.output))).unzip if (projExprs.forall(_.isDefined) && childOp.nonEmpty) { val expandBuilder = OperatorOuterClass.Expand .newBuilder() .addAllProjectList(projExprs.map(_.get).asJava) .setNumExprPerProject(projections.head.size) - Some(result.setExpand(expandBuilder).build()) + (Some(result.setExpand(expandBuilder).build()), CometExplainInfo.none) } else { - None + unsupported("CometExpand", projInfos) } case HashAggregateExec( @@ -1840,10 +2149,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { 
resultExpressions, child) if isCometOperatorEnabled(op.conf, "aggregate") => if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) { - return None + return unsupported("CometHashAggregate", CometExplainInfo("No group by or aggregation")) } - val groupingExprs = groupingExpressions.map(exprToProto(_, child.output)) + val (groupingExprs, groupingExprsInfos) = + groupingExpressions.map(exprToProto(_, child.output)).unzip // In some of the cases, the aggregateExpressions could be empty. // For example, if the aggregate functions only have group by or if the aggregate @@ -1863,37 +2173,78 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes - val resultExprs = resultExpressions.map(exprToProto(_, attributes)) + val (resultExprs, _) = resultExpressions.map(exprToProto(_, attributes)).unzip if (resultExprs.exists(_.isEmpty)) { - emitWarning(s"Unsupported result expressions found in: ${resultExpressions}") - return None + val msg = s"Unsupported result expressions found in: ${resultExpressions}" + emitWarning(msg) + return unsupported("CometHashAggregate", CometExplainInfo(msg)) } hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) - Some(result.setHashAgg(hashAggBuilder).build()) + (Some(result.setHashAgg(hashAggBuilder).build()), CometExplainInfo.none) } else { val modes = aggregateExpressions.map(_.mode).distinct if (modes.size != 1) { // This shouldn't happen as all aggregation expressions should share the same mode. // Fallback to Spark nevertheless here. - return None + return unsupported( + "CometHashAggregate", + CometExplainInfo("All aggregate expressions do not have the same mode")) } val mode = modes.head match { case Partial => CometAggregateMode.Partial case Final => CometAggregateMode.Final - case _ => return None + case _ => + return unsupported( + "CometHashAggregate", + CometExplainInfo(s"Unsupported aggregation mode ${modes.head}")) } - // In final mode, the aggregate expressions are bound to the output of the - // child and partial aggregate expressions buffer attributes produced by partial - // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet, - // we don't have to do this because we don't use the merging expression. - val binding = mode != CometAggregateMode.Final - // `output` is only used when `binding` is true (i.e., non-Final) - val output = child.output + val output = mode match { + case CometAggregateMode.Partial => child.output + case CometAggregateMode.Final => + // Assuming `Final` always follows `Partial` aggregation, this find the first + // `Partial` aggregation and get the input attributes from it. + // During finding partial aggregation, we must ensure all traversed op are + // native operators. If not, we should fallback to Spark. 
+ var seenNonNativeOp = false + var partialAggInput: Option[Seq[Attribute]] = None + child.transformDown { + case op if !op.isInstanceOf[CometPlan] => + seenNonNativeOp = true + op + case op @ CometHashAggregateExec(_, _, _, _, input, Some(Partial), _, _) => + if (!seenNonNativeOp && partialAggInput.isEmpty) { + partialAggInput = Some(input) + } + op + } - val aggExprs = aggregateExpressions.map(aggExprToProto(_, output, binding)) + if (partialAggInput.isDefined) { + partialAggInput.get + } else { + return unsupported( + "CometHashAggregate", + CometExplainInfo("No input for partial aggregate")) + } + case _ => + return unsupported( + "CometHashAggregate", + CometExplainInfo(s"Unsupported mode $mode")) + } + val binding = if (mode == CometAggregateMode.Final) { +// In final mode, the aggregate expressions are bound to the output of the +// child and partial aggregate expressions buffer attributes produced by partial +// aggregation. This is done in Spark `HashAggregateExec` internally. In Comet, +// we don't have to do this because we don't use the merging expression. + false + } else { + true + } + + val (aggExprs, aggExprsInfos) = + aggregateExpressions.map(aggExprToProto(_, output, binding)).unzip if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) && aggExprs.forall(_.isDefined)) { val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() @@ -1901,17 +2252,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava) if (mode == CometAggregateMode.Final) { val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes - val resultExprs = resultExpressions.map(exprToProto(_, attributes)) + val (resultExprs, _) = resultExpressions.map(exprToProto(_, attributes)).unzip if (resultExprs.exists(_.isEmpty)) { - emitWarning(s"Unsupported result expressions found in: ${resultExpressions}") - return None + val msg = s"Unsupported result expressions found in: ${resultExpressions}" + emitWarning(msg) + return unsupported("CometHashAggregate", CometExplainInfo(msg)) } hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) } hashAggBuilder.setModeValue(mode.getNumber) - Some(result.setHashAgg(hashAggBuilder).build()) + (Some(result.setHashAgg(hashAggBuilder).build()), CometExplainInfo.none) } else { - None + unsupported("CometHashAggregate", aggExprsInfos ++ groupingExprsInfos) } } @@ -1922,19 +2274,19 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { join.isInstanceOf[ShuffledHashJoinExec]) && !(isCometOperatorEnabled(op.conf, "broadcast_hash_join") && join.isInstanceOf[BroadcastHashJoinExec])) { - return None + return unsupported("HashJoin", s"Invalid hash join type ${join.nodeName}") } if (join.buildSide == BuildRight) { // DataFusion HashJoin assumes build side is always left. 
// TODO: support BuildRight - return None + return unsupported("HashJoin", "BuildRight is not supported") } val condition = join.condition.map { cond => - val condProto = exprToProto(cond, join.left.output ++ join.right.output) + val (condProto, condInfo) = exprToProto(cond, join.left.output ++ join.right.output) if (condProto.isEmpty) { - return None + return unsupported("HashJoin", condInfo) } condProto.get } @@ -1946,11 +2298,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case FullOuter => JoinType.FullOuter case LeftSemi => JoinType.LeftSemi case LeftAnti => JoinType.LeftAnti - case _ => return None // Spark doesn't support other join types + case _ => + return unsupported( + "HashJoin", + s"Unsupported join type ${join.joinType}" + ) // Spark doesn't support other join types } - val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output)) - val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output)) + val (leftKeys, leftInfos) = join.leftKeys.map(exprToProto(_, join.left.output)).unzip + val (rightKeys, rightInfos) = join.rightKeys.map(exprToProto(_, join.right.output)).unzip if (leftKeys.forall(_.isDefined) && rightKeys.forall(_.isDefined) && @@ -1961,9 +2317,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .addAllLeftJoinKeys(leftKeys.map(_.get).asJava) .addAllRightJoinKeys(rightKeys.map(_.get).asJava) condition.foreach(joinBuilder.setCondition) - Some(result.setHashJoin(joinBuilder).build()) + (Some(result.setHashJoin(joinBuilder).build()), CometExplainInfo.none) } else { - None + unsupported("HashJoin", leftInfos ++ rightInfos) } case join: SortMergeJoinExec if isCometOperatorEnabled(op.conf, "sort_merge_join") => @@ -1988,7 +2344,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { // TODO: Support SortMergeJoin with join condition after new DataFusion release if (join.condition.isDefined) { - return None + return unsupported( + op.nodeName, + CometExplainInfo("Sort merge join with a join condition is not supported")) } val joinType = join.joinType match { @@ -1998,14 +2356,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case FullOuter => JoinType.FullOuter case LeftSemi => JoinType.LeftSemi case LeftAnti => JoinType.LeftAnti - case _ => return None // Spark doesn't support other join types + case _ => + return unsupported( + op.nodeName, + CometExplainInfo(s"Unsupported join type ${join.joinType}") + ) // Spark doesn't support other join types } - val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output)) - val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output)) + val (leftKeys, leftInfo) = join.leftKeys.map(exprToProto(_, join.left.output)).unzip + val (rightKeys, rightInfo) = join.rightKeys.map(exprToProto(_, join.right.output)).unzip - val sortOptions = getKeyOrdering(join.leftKeys, join.left.outputOrdering) - .map(exprToProto(_, join.left.output)) + val (sortOptions, sortOptionsInfo) = + getKeyOrdering(join.leftKeys, join.left.outputOrdering) + .map(exprToProto(_, join.left.output)) + .unzip if (sortOptions.forall(_.isDefined) && leftKeys.forall(_.isDefined) && @@ -2017,9 +2381,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .addAllSortOptions(sortOptions.map(_.get).asJava) .addAllLeftJoinKeys(leftKeys.map(_.get).asJava) .addAllRightJoinKeys(rightKeys.map(_.get).asJava) - Some(result.setSortMergeJoin(joinBuilder).build()) + (Some(result.setSortMergeJoin(joinBuilder).build()), CometExplainInfo.none) } else { - None + + 
unsupported(op.nodeName, leftInfo ++ rightInfo ++ sortOptionsInfo) } case op if isCometSink(op) => @@ -2036,12 +2401,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { // Sink operators don't have children result.clearChildren() - Some(result.setScan(scanBuilder).build()) + (Some(result.setScan(scanBuilder).build()), CometExplainInfo.none) } else { // There are unsupported scan type - emitWarning( - s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above") - None + val msg = + s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above" + emitWarning(msg) + unsupported(op.nodeName, msg) } case op => @@ -2051,7 +2417,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { if (!op.nodeName.contains("Comet") && !op.isInstanceOf[ShuffleExchangeExec]) { emitWarning(s"unsupported Spark operator: ${op.nodeName}") } - None + unsupported(op.nodeName, CometExplainInfo.none) } } @@ -2067,7 +2433,6 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case _: CometRowToColumnarExec => true case _: CometSinkPlaceHolder => true case _: CoalesceExec => true - case _: CollectLimitExec => true case _: UnionExec => true case _: ShuffleExchangeExec => true case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true @@ -2093,7 +2458,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle * which supports struct/array. */ - def supportPartitioningTypes(inputs: Seq[Attribute]): Boolean = { + def supportPartitioningTypes( + inputs: Seq[Attribute], + partitioning: Partitioning): (Boolean, String) = { def supportedDataType(dt: DataType): Boolean = dt match { case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType | @@ -2120,15 +2487,31 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { // Check if the datatypes of shuffle input are supported. val supported = inputs.forall(attr => supportedDataType(attr.dataType)) if (!supported) { - emitWarning(s"unsupported Spark partitioning: ${inputs.map(_.dataType)}") + val msg = s"unsupported Spark partitioning: ${inputs.map(_.dataType)}" + emitWarning(msg) + (false, msg) + } else { + partitioning match { + case HashPartitioning(expressions, _) => + (expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_._1.isDefined), null) + case SinglePartition => (true, null) + case _: RoundRobinPartitioning => (true, null) + case RangePartitioning(ordering, _) => + (ordering.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_._1.isDefined), null) + case other => + val msg = s"unsupported Spark partitioning: ${other.getClass.getName}" + emitWarning(msg) + (false, msg) + } } - supported } /** * Whether the given Spark partitioning is supported by Comet. 
*/ - def supportPartitioning(inputs: Seq[Attribute], partitioning: Partitioning): Boolean = { + def supportPartitioning( + inputs: Seq[Attribute], + partitioning: Partitioning): (Boolean, String) = { def supportedDataType(dt: DataType): Boolean = dt match { case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType | @@ -2143,17 +2526,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val supported = inputs.forall(attr => supportedDataType(attr.dataType)) if (!supported) { - emitWarning(s"unsupported Spark partitioning: ${inputs.map(_.dataType)}") - false + val msg = s"unsupported Spark partitioning: ${inputs.map(_.dataType)}" + emitWarning(msg) + (false, msg) } else { partitioning match { case HashPartitioning(expressions, _) => - expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) - case SinglePartition => true + (expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_._1.isDefined), null) + case SinglePartition => (true, null) case other => - emitWarning(s"unsupported Spark partitioning: ${other.getClass.getName}") - false + val msg = s"unsupported Spark partitioning: ${other.getClass.getName}" + emitWarning(msg) + (false, msg) } } } + } diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala index 85c6413e13..4291d66054 100644 --- a/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala @@ -20,7 +20,7 @@ package org.apache.comet.shims import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.execution.{LimitExec, SparkPlan} +import org.apache.spark.sql.execution.{LimitExec, QueryExecution, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan trait ShimCometSparkSessionExtensions { @@ -49,4 +49,17 @@ object ShimCometSparkSessionExtensions { .filter(_.isInstanceOf[Int]) .map(_.asInstanceOf[Int]) .headOption + + def supportsExtendedExplainInfo(qe: QueryExecution): Boolean = { + try { + // Look for QueryExecution.extendedExplainInfo(scala.Function1[String, Unit], SparkPlan) + qe.getClass.getDeclaredMethod( + "extendedExplainInfo", + classOf[String => Unit], + classOf[SparkPlan]) + } catch { + case _: NoSuchMethodException | _: SecurityException => return false + } + true + } } diff --git a/spark/src/main/scala/org/apache/spark/sql/ExtendedExplainGenerator.scala b/spark/src/main/scala/org/apache/spark/sql/ExtendedExplainGenerator.scala new file mode 100644 index 0000000000..e72c894b3d --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/ExtendedExplainGenerator.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * A trait for a session extension to implement that provides additional explain plan information.
+ */
+
+trait ExtendedExplainGenerator {
+  def title: String
+
+  def generateExtendedInfo(plan: SparkPlan): String
+}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
index 5931920a20..8bec032aa4 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
@@ -67,7 +67,7 @@ object CometExecUtils {
       child: SparkPlan,
       limit: Int): Option[Operator] = {
     getTopKNativePlan(outputAttributes, sortOrder, child, limit).flatMap { topK =>
-      val exprs = projectList.map(exprToProto(_, child.output))
+      val (exprs, exprsInfo) = projectList.map(exprToProto(_, child.output)).unzip
 
       if (exprs.forall(_.isDefined)) {
         val projectBuilder = OperatorOuterClass.Projection.newBuilder()
@@ -127,7 +127,7 @@ object CometExecUtils {
     if (scanTypes.length == outputAttributes.length) {
       scanBuilder.addAllFields(scanTypes.asJava)
 
-      val sortOrders = sortOrder.map(exprToProto(_, child.output))
+      val (sortOrders, sortInfos) = sortOrder.map(exprToProto(_, child.output)).unzip
 
       if (sortOrders.forall(_.isDefined)) {
         val sortBuilder = OperatorOuterClass.Sort.newBuilder()
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
index 26ec401ed6..fc6a5927bc 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.metric.{SQLMetrics, SQLShuffleReadMetricsR
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
+import org.apache.comet.CometExplainInfo
 import org.apache.comet.serde.QueryPlanSerde.exprToProto
 import org.apache.comet.shims.ShimCometTakeOrderedAndProjectExec
 
@@ -122,10 +123,14 @@ case class CometTakeOrderedAndProjectExec(
 object CometTakeOrderedAndProjectExec extends ShimCometTakeOrderedAndProjectExec {
   // TODO: support offset for Spark 3.4
-  def isSupported(plan: TakeOrderedAndProjectExec): Boolean = {
+  def isSupported(plan: TakeOrderedAndProjectExec): (Boolean, CometExplainInfo) = {
     val exprs = plan.projectList.map(exprToProto(_, plan.child.output))
     val sortOrders = plan.sortOrder.map(exprToProto(_, plan.child.output))
-    exprs.forall(_.isDefined) && sortOrders.forall(_.isDefined) && getOffset(plan).getOrElse(
-      0) == 0
+    val isSupportedForAll = exprs.forall(_._1.isDefined) && sortOrders.forall(_._1.isDefined)
+    if (isSupportedForAll) {
+      (true, CometExplainInfo.none)
+    } else {
+      (false, CometExplainInfo("TakeOrderedAndProject", exprs.map(_._2) ++ sortOrders.map(_._2)))
+    }
   }
 }
diff --git
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 232b6bf17f..c67299519a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -519,7 +519,10 @@ class CometShuffleWriteProcessor( partitioning.setNumPartitions(outputPartitioning.numPartitions) val partitionExprs = hashPartitioning.expressions - .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) + .flatMap(e => { + val (op, _) = QueryPlanSerde.exprToProto(e, outputAttributes) + op + }) if (partitionExprs.length != hashPartitioning.expressions.length) { throw new UnsupportedOperationException( diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 803f30bed0..87c2126af6 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1383,4 +1383,55 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { testCastedColumn(inputValues = Seq("car", "Truck")) } + test("explain comet") { + assume(isSpark34Plus) + withSQLConf( + SQLConf.ANSI_ENABLED.key -> "false", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key -> "true", + CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true", + "spark.sql.extendedExplainProvider" -> "org.apache.comet.ExtendedExplainInfo") { + val table = "test" + withTable(table) { + sql(s"create table $table(c0 int, c1 int , c2 float) using parquet") + sql(s"insert into $table values(0, 1, 100.000001)") + + Seq( + ( + s"SELECT cast(make_interval(c0, c1, c0, c1, c0, c0, c2) as string) as C from $table", + "make_interval is not supported"), + ( + "SELECT " + + "date_part('YEAR', make_interval(c0, c1, c0, c1, c0, c0, c2))" + + " + " + + "date_part('MONTH', make_interval(c0, c1, c0, c1, c0, c0, c2))" + + s" as yrs_and_mths from $table", + "extractintervalyears is not supported\n" + + "extractintervalmonths is not supported"), + ( + s"SELECT sum(c0), sum(c2) from $table group by c1", + "Native shuffle is not enabled\n" + + "AQEShuffleRead is not supported"), + ( + "SELECT A.c1, A.sum_c0, A.sum_c2, B.casted from " + + s"(SELECT c1, sum(c0) as sum_c0, sum(c2) as sum_c2 from $table group by c1) as A, " + + s"(SELECT c1, cast(make_interval(c0, c1, c0, c1, c0, c0, c2) as string) as casted from $table) as B " + + "where A.c1 = B.c1 ", + "Native shuffle is not enabled\n" + + "AQEShuffleRead is not supported\n" + + "make_interval is not supported\n" + + "BroadcastExchange is not supported\n" + + "BroadcastHashJoin is not supported")) + .foreach(test => { + val qry = test._1 + val expected = test._2 + val df = sql(qry) + df.collect() // force an execution + checkSparkAnswerAndCompareExplainPlan(df, expected) + }) + } + } + } + } diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCQueryBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCQueryBase.scala index 8f83ac04b3..ad989f58cc 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTPCQueryBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCQueryBase.scala @@ -46,6 +46,9 @@ trait CometTPCQueryBase extends Logging { 
.set("spark.sql.autoBroadcastJoinThreshold", (20 * 1024 * 1024).toString) .set("spark.sql.crossJoin.enabled", "true") .setIfMissing("parquet.enable.dictionary", "true") + .set( + "spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") val sparkSession = SparkSession.builder .config(conf) diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCQueryListBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCQueryListBase.scala index 1f28b76a1e..e0d28c647b 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTPCQueryListBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCQueryListBase.scala @@ -31,12 +31,15 @@ import org.apache.spark.sql.comet.CometExec import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.comet.CometConf +import org.apache.comet.{CometConf, ExtendedExplainInfo} +import org.apache.comet.shims.ShimCometSparkSessionExtensions +import org.apache.comet.shims.ShimCometSparkSessionExtensions.supportsExtendedExplainInfo trait CometTPCQueryListBase extends CometTPCQueryBase with AdaptiveSparkPlanHelper - with SQLHelper { + with SQLHelper + with ShimCometSparkSessionExtensions { var output: Option[OutputStream] = None def main(args: Array[String]): Unit = { @@ -84,12 +87,16 @@ trait CometTPCQueryListBase withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true") { val df = cometSpark.sql(queryString) val cometPlans = mutable.HashSet.empty[String] - stripAQEPlan(df.queryExecution.executedPlan).foreach { case op: CometExec => - cometPlans += s"${op.nodeName}" + val executedPlan = df.queryExecution.executedPlan + stripAQEPlan(executedPlan).foreach { + case op: CometExec => + cometPlans += s"${op.nodeName}" + case _ => } if (cometPlans.nonEmpty) { @@ -98,6 +105,11 @@ trait CometTPCQueryListBase } else { out.println(s"Query: $name$nameSuffix. 
Comet Exec: Disabled") } + if (supportsExtendedExplainInfo(df.queryExecution)) { + out.println( + s"Query: $name$nameSuffix: ExplainInfo:\n" + + s"${new ExtendedExplainInfo().generateExtendedInfo(executedPlan)}\n") + } } } } diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index de5866580b..24cff29d16 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -25,6 +25,7 @@ import scala.reflect.runtime.universe.TypeTag import org.scalatest.BeforeAndAfterEach +import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.Path import org.apache.parquet.column.ParquetProperties import org.apache.parquet.example.data.Group @@ -36,7 +37,7 @@ import org.apache.spark._ import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE, SHUFFLE_MANAGER} import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometExec, CometRowToColumnarExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec} -import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{ColumnarToRowExec, ExtendedMode, InputAdapter, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.internal._ import org.apache.spark.sql.test._ @@ -45,6 +46,8 @@ import org.apache.spark.sql.types.StructType import org.apache.comet._ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus +import org.apache.comet.shims.ShimCometSparkSessionExtensions +import org.apache.comet.shims.ShimCometSparkSessionExtensions.supportsExtendedExplainInfo /** * Base class for testing. This exists in `org.apache.spark.sql` since [[SQLTestUtils]] is @@ -54,7 +57,8 @@ abstract class CometTestBase extends QueryTest with SQLTestUtils with BeforeAndAfterEach - with AdaptiveSparkPlanHelper { + with AdaptiveSparkPlanHelper + with ShimCometSparkSessionExtensions { import testImplicits._ protected val shuffleManager: String = @@ -215,6 +219,30 @@ abstract class CometTestBase checkAnswerWithTol(dfComet, expected, absTol: Double) } + protected def checkSparkAnswerAndCompareExplainPlan( + df: DataFrame, + expectedInfo: String): Unit = { + var expected: Array[Row] = Array.empty + var dfSpark: Dataset[Row] = null + withSQLConf( + CometConf.COMET_ENABLED.key -> "false", + "spark.sql.extendedExplainProvider" -> "") { + dfSpark = Dataset.ofRows(spark, df.logicalPlan) + expected = dfSpark.collect() + } + val dfComet = Dataset.ofRows(spark, df.logicalPlan) + checkAnswer(dfComet, expected) + val diff = StringUtils.difference( + dfSpark.queryExecution.explainString(ExtendedMode), + dfComet.queryExecution.explainString(ExtendedMode)) + if (supportsExtendedExplainInfo(dfSpark.queryExecution)) { + assert(diff.contains(expectedInfo)) + } + val extendedInfo = + new ExtendedExplainInfo().generateExtendedInfo(dfComet.queryExecution.executedPlan) + assert(extendedInfo.equalsIgnoreCase(expectedInfo)) + } + private var _spark: SparkSession = _ protected implicit def spark: SparkSession = _spark protected implicit def sqlContext: SQLContext = _spark.sqlContext
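
Usage sketch (illustrative only, not part of the patch): the diff above threads a CometExplainInfo fallback reason through expression and operator serialization and publishes it through the new ExtendedExplainGenerator hook. Assuming a Spark build whose QueryExecution exposes extendedExplainInfo (the capability probed by supportsExtendedExplainInfo), and assuming the Comet session-extension class name and spark.comet.* keys below (taken from Comet's setup documentation, not from this patch), the fallback reasons can be inspected roughly as follows:

import org.apache.spark.sql.SparkSession
import org.apache.comet.ExtendedExplainInfo

val spark = SparkSession
  .builder()
  .master("local[1]")
  // Load the Comet extensions (class name assumed from Comet's setup docs).
  .config("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
  .config("spark.comet.enabled", "true")
  .config("spark.comet.exec.enabled", "true")
  // Register the generator added by this patch so the extended explain output
  // carries the per-operator fallback reasons collected in CometExplainInfo.
  .config("spark.sql.extendedExplainProvider", "org.apache.comet.ExtendedExplainInfo")
  .getOrCreate()

val df = spark.sql("SELECT cast(make_interval(1, 2, 0, 1, 0, 0, 0.0) as string)")
df.collect() // force execution so the final physical plan is materialized

// On Spark versions that support the extended explain hook, the extra
// information is appended to the extended-mode plan.
df.explain("extended")

// The generator can also be called directly on the executed plan, as the
// TPC query list suites in this patch do.
println(new ExtendedExplainInfo().generateExtendedInfo(df.queryExecution.executedPlan))

For this query the generated info would be expected to include a line such as "make_interval is not supported", mirroring the expectations in the new CometExpressionSuite test above.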