From c528cefd58b4494525dfb492ea64a9baa5291603 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 19 Aug 2024 17:31:11 -0600
Subject: [PATCH 1/6] plugin register session extension

---
 native/Cargo.lock                             | 45 ++++++++++++-------
 .../main/scala/org/apache/spark/Plugins.scala | 15 +++++++
 2 files changed, 45 insertions(+), 15 deletions(-)

diff --git a/native/Cargo.lock b/native/Cargo.lock
index 27bc3828c..3843e46a9 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -804,7 +804,8 @@ dependencies = [
 [[package]]
 name = "datafusion"
 version = "41.0.0"
-source = "git+https://github.com/apache/datafusion.git?rev=41.0.0-rc1#b10b820acb6ad92b5d69810e3d4de0ef6f2d6a87"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e4fd4a99fc70d40ef7e52b243b4a399c3f8d353a40d5ecb200deee05e49c61bb"
 dependencies = [
  "ahash",
  "arrow",
@@ -851,7 +852,8 @@ dependencies = [
 [[package]]
 name = "datafusion-catalog"
 version = "41.0.0"
-source = "git+https://github.com/apache/datafusion.git?rev=41.0.0-rc1#b10b820acb6ad92b5d69810e3d4de0ef6f2d6a87"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e13b3cfbd84c6003594ae1972314e3df303a27ce8ce755fcea3240c90f4c0529"
 dependencies = [
  "arrow-schema",
  "async-trait",
@@ -949,7 +951,8 @@ dependencies = [
 [[package]]
 name = "datafusion-common"
 version = "41.0.0"
-source = "git+https://github.com/apache/datafusion.git?rev=41.0.0-rc1#b10b820acb6ad92b5d69810e3d4de0ef6f2d6a87"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "44fdbc877e3e40dcf88cc8f283d9f5c8851f0a3aa07fee657b1b75ac1ad49b9c"
 dependencies = [
  "ahash",
  "arrow",
@@ -969,7 +972,8 @@ dependencies = [
 [[package]]
 name = "datafusion-common-runtime"
 version = "41.0.0"
-source = "git+https://github.com/apache/datafusion.git?rev=41.0.0-rc1#b10b820acb6ad92b5d69810e3d4de0ef6f2d6a87"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8a7496d1f664179f6ce3a5cbef6566056ccaf3ea4aa72cc455f80e62c1dd86b1"
 dependencies = [
  "tokio",
 ]
@@ -977,7 +981,8 @@ dependencies = [
 [[package]]
 name = "datafusion-execution"
 version = "41.0.0"
-source = "git+https://github.com/apache/datafusion.git?rev=41.0.0-rc1#b10b820acb6ad92b5d69810e3d4de0ef6f2d6a87"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "799e70968c815b611116951e3dd876aef04bf217da31b72eec01ee6a959336a1"
 dependencies = [
  "arrow",
  "chrono",
@@ -997,7 +1002,8 @@ dependencies = [
 [[package]]
 name = "datafusion-expr"
 version = "41.0.0"
-source = "git+https://github.com/apache/datafusion.git?rev=41.0.0-rc1#b10b820acb6ad92b5d69810e3d4de0ef6f2d6a87"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1c1841c409d9518c17971d15c9bae62e629eb937e6fb6c68cd32e9186f8b30d2"
 dependencies = [
  "ahash",
  "arrow",
@@ -1015,7 +1021,8 @@ dependencies = [
 [[package]]
 name = "datafusion-functions"
 version = "41.0.0"
-source = "git+https://github.com/apache/datafusion.git?rev=41.0.0-rc1#b10b820acb6ad92b5d69810e3d4de0ef6f2d6a87"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a8e481cf34d2a444bd8fa09b65945f0ce83dc92df8665b761505b3d9f351bebb"
 dependencies = [
  "arrow",
  "arrow-buffer",
@@ -1041,7 +1048,8 @@ dependencies = [
 [[package]]
 name = "datafusion-functions-aggregate"
 version = "41.0.0"
-source = "git+https://github.com/apache/datafusion.git?rev=41.0.0-rc1#b10b820acb6ad92b5d69810e3d4de0ef6f2d6a87"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2b4ece19f73c02727e5e8654d79cd5652de371352c1df3c4ac3e419ecd6943fb"
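Taken on its own, the registration logic this patch adds to `CometDriverPlugin.init` can be exercised against a bare `SparkConf`. A minimal standalone sketch of that logic (not part of the patch; `com.example.MyExtensions` is a hypothetical third-party extension) also exposes a subtlety in the `else` branch: when Comet is already registered alongside other extensions, the key is overwritten with the Comet class alone, dropping the others. Patch 4 in this series ("add unit test and fix bug") removes that overwrite:

```scala
import org.apache.spark.SparkConf

// Standalone sketch of the patch 1 logic, lifted out of init() so it can
// run against a plain SparkConf. "spark.sql.extensions" is the key behind
// StaticSQLConf.SPARK_SESSION_EXTENSIONS.
object RegistrationSketch {
  private val extensionKey = "spark.sql.extensions"
  private val extensionClass = "org.apache.comet.CometSparkSessionExtensions"

  def registerAsInPatch1(conf: SparkConf): Unit = {
    if (conf.contains(extensionKey)) {
      val extensions = conf.get(extensionKey)
      if (!extensions.split(",").map(_.trim).contains(extensionClass)) {
        conf.set(extensionKey, s"$extensions,$extensionClass")
      } else {
        // overwrites the whole list when Comet is already present
        conf.set(extensionKey, extensionClass)
      }
    } else {
      conf.set(extensionKey, extensionClass)
    }
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .set(extensionKey, s"com.example.MyExtensions,$extensionClass")
    registerAsInPatch1(conf)
    // prints only the Comet class; com.example.MyExtensions has been lost
    println(conf.get(extensionKey))
  }
}
```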
"2b4ece19f73c02727e5e8654d79cd5652de371352c1df3c4ac3e419ecd6943fb" dependencies = [ "ahash", "arrow", @@ -1058,7 +1066,8 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=41.0.0-rc1#b10b820acb6ad92b5d69810e3d4de0ef6f2d6a87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1474552cc824e8c9c88177d454db5781d4b66757d4aca75719306b8343a5e8d" dependencies = [ "arrow", "arrow-array", @@ -1079,7 +1088,8 @@ dependencies = [ [[package]] name = "datafusion-optimizer" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=41.0.0-rc1#b10b820acb6ad92b5d69810e3d4de0ef6f2d6a87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "791ff56f55608bc542d1ea7a68a64bdc86a9413f5a381d06a39fd49c2a3ab906" dependencies = [ "arrow", "async-trait", @@ -1098,7 +1108,8 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=41.0.0-rc1#b10b820acb6ad92b5d69810e3d4de0ef6f2d6a87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a223962b3041304a3e20ed07a21d5de3d88d7e4e71ca192135db6d24e3365a4" dependencies = [ "ahash", "arrow", @@ -1127,7 +1138,8 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=41.0.0-rc1#b10b820acb6ad92b5d69810e3d4de0ef6f2d6a87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db5e7d8532a1601cd916881db87a70b0a599900d23f3db2897d389032da53bc6" dependencies = [ "ahash", "arrow", @@ -1140,7 +1152,8 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=41.0.0-rc1#b10b820acb6ad92b5d69810e3d4de0ef6f2d6a87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb9c78f308e050f5004671039786a925c3fee83b90004e9fcfd328d7febdcc0" dependencies = [ "datafusion-common", "datafusion-execution", @@ -1151,7 +1164,8 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=41.0.0-rc1#b10b820acb6ad92b5d69810e3d4de0ef6f2d6a87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d1116949432eb2d30f6362707e2846d942e491052a206f2ddcb42d08aea1ffe" dependencies = [ "ahash", "arrow", @@ -1184,7 +1198,8 @@ dependencies = [ [[package]] name = "datafusion-sql" version = "41.0.0" -source = "git+https://github.com/apache/datafusion.git?rev=41.0.0-rc1#b10b820acb6ad92b5d69810e3d4de0ef6f2d6a87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b45d0180711165fe94015d7c4123eb3e1cf5fb60b1506453200b8d1ce666bef0" dependencies = [ "arrow", "arrow-array", diff --git a/spark/src/main/scala/org/apache/spark/Plugins.scala b/spark/src/main/scala/org/apache/spark/Plugins.scala index dcc00f66c..eb3602354 100644 --- a/spark/src/main/scala/org/apache/spark/Plugins.scala +++ b/spark/src/main/scala/org/apache/spark/Plugins.scala @@ -26,6 +26,7 @@ import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, import org.apache.spark.comet.shims.ShimCometDriverPlugin import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} +import org.apache.spark.sql.internal.StaticSQLConf import 
 
@@ -44,6 +45,20 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
   override def init(sc: SparkContext, pluginContext: PluginContext): ju.Map[String, String] = {
     logInfo("CometDriverPlugin init")
 
+    // register CometSparkSessionExtensions if it isn't already registered
+    val extensionKey = StaticSQLConf.SPARK_SESSION_EXTENSIONS.key
+    val extensionClass = classOf[CometSparkSessionExtensions].getName
+    if (sc.conf.contains(extensionKey)) {
+      val extensions = sc.conf.get(extensionKey)
+      if (!extensions.split(",").map(_.trim).contains(extensionClass)) {
+        sc.conf.set(extensionKey, s"$extensions,$extensionClass")
+      } else {
+        sc.conf.set(extensionKey, extensionClass)
+      }
+    } else {
+      sc.conf.set(extensionKey, extensionClass)
+    }
+
     if (shouldOverrideMemoryConf(sc.getConf)) {
       val execMemOverhead = if (sc.getConf.contains(EXECUTOR_MEMORY_OVERHEAD.key)) {
         sc.getConf.getSizeAsMb(EXECUTOR_MEMORY_OVERHEAD.key)

From 8907196f5779b6bc780644bf50d01b767291797a Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 19 Aug 2024 17:33:44 -0600
Subject: [PATCH 2/6] update docs

---
 docs/source/contributor-guide/benchmarking.md    | 2 +-
 docs/source/contributor-guide/debugging.md       | 2 +-
 docs/source/contributor-guide/plugin_overview.md | 4 ++--
 docs/source/user-guide/installation.md           | 2 +-
 fuzz-testing/README.md                           | 2 +-
 5 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/docs/source/contributor-guide/benchmarking.md b/docs/source/contributor-guide/benchmarking.md
index 5f4f10912..abd35e187 100644
--- a/docs/source/contributor-guide/benchmarking.md
+++ b/docs/source/contributor-guide/benchmarking.md
@@ -60,7 +60,7 @@ $SPARK_HOME/bin/spark-submit \
     --jars $COMET_JAR \
     --conf spark.driver.extraClassPath=$COMET_JAR \
     --conf spark.executor.extraClassPath=$COMET_JAR \
-    --conf spark.sql.extensions=org.apache.comet.CometSparkSessionExtensions \
+    org.apache.spark.CometPlugin \
     --conf spark.comet.enabled=true \
     --conf spark.comet.exec.enabled=true \
     --conf spark.comet.exec.all.enabled=true \
diff --git a/docs/source/contributor-guide/debugging.md b/docs/source/contributor-guide/debugging.md
index d1f62a5db..8aad68698 100644
--- a/docs/source/contributor-guide/debugging.md
+++ b/docs/source/contributor-guide/debugging.md
@@ -130,7 +130,7 @@ Then build the Comet as [described](https://github.com/apache/arrow-datafusion-c
 Start Comet with `RUST_BACKTRACE=1`
 
 ```console
-RUST_BACKTRACE=1 $SPARK_HOME/spark-shell --jars spark/target/comet-spark-spark3.4_2.12-0.1.0-SNAPSHOT.jar --conf spark.sql.extensions=org.apache.comet.CometSparkSessionExtensions --conf spark.comet.enabled=true --conf spark.comet.exec.enabled=true --conf spark.comet.exec.all.enabled=true
+RUST_BACKTRACE=1 $SPARK_HOME/spark-shell --jars spark/target/comet-spark-spark3.4_2.12-0.1.0-SNAPSHOT.jar org.apache.spark.CometPlugin --conf spark.comet.enabled=true --conf spark.comet.exec.enabled=true --conf spark.comet.exec.all.enabled=true
 ```
 
 Get the expanded exception details
diff --git a/docs/source/contributor-guide/plugin_overview.md b/docs/source/contributor-guide/plugin_overview.md
index 8b48818ef..9e6a104b0 100644
--- a/docs/source/contributor-guide/plugin_overview.md
+++ b/docs/source/contributor-guide/plugin_overview.md
@@ -19,10 +19,10 @@ under the License.
 
 # Comet Plugin Overview
 
-The entry point to Comet is the `org.apache.comet.CometSparkSessionExtensions` class, which can be registered with Spark by adding the following setting to the Spark configuration when launching `spark-shell` or `spark-submit`:
+The entry point to Comet is the `org.apache.spark.CometPlugin` class, which can be registered with Spark by adding the following setting to the Spark configuration when launching `spark-shell` or `spark-submit`:
 
 ```
---conf spark.sql.extensions=org.apache.comet.CometSparkSessionExtensions
+--conf spark.plugins=org.apache.spark.CometPlugin
 ```
 
 On initialization, this class registers two physical plan optimization rules with Spark: `CometScanRule` and `CometExecRule`. These rules run whenever a query stage is being planned.
diff --git a/docs/source/user-guide/installation.md b/docs/source/user-guide/installation.md
index cb7f032d6..b79577ee5 100644
--- a/docs/source/user-guide/installation.md
+++ b/docs/source/user-guide/installation.md
@@ -85,7 +85,7 @@ $SPARK_HOME/bin/spark-shell \
     --jars $COMET_JAR \
     --conf spark.driver.extraClassPath=$COMET_JAR \
     --conf spark.executor.extraClassPath=$COMET_JAR \
-    --conf spark.sql.extensions=org.apache.comet.CometSparkSessionExtensions \
+    --conf spark.plugins=org.apache.spark.CometPlugin \
     --conf spark.comet.enabled=true \
     --conf spark.comet.exec.enabled=true \
     --conf spark.comet.exec.all.enabled=true \
diff --git a/fuzz-testing/README.md b/fuzz-testing/README.md
index 4b371de99..25589385c 100644
--- a/fuzz-testing/README.md
+++ b/fuzz-testing/README.md
@@ -86,7 +86,7 @@ Note that the output filename is currently hard-coded as `queries.sql`
 ```shell
 $SPARK_HOME/bin/spark-submit \
     --master $SPARK_MASTER \
-    --conf spark.sql.extensions=org.apache.comet.CometSparkSessionExtensions \
+    org.apache.spark.CometPlugin \
     --conf spark.comet.enabled=true \
     --conf spark.comet.exec.enabled=true \
     --conf spark.comet.exec.all.enabled=true \

From 7befdd0310fe91b9e0b7fe501ed8a60bd48cbd37 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 19 Aug 2024 18:11:04 -0600
Subject: [PATCH 3/6] fix

---
 docs/source/contributor-guide/benchmarking.md | 2 +-
 docs/source/contributor-guide/debugging.md    | 2 +-
 fuzz-testing/README.md                        | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/docs/source/contributor-guide/benchmarking.md b/docs/source/contributor-guide/benchmarking.md
index abd35e187..1383df0a8 100644
--- a/docs/source/contributor-guide/benchmarking.md
+++ b/docs/source/contributor-guide/benchmarking.md
@@ -60,7 +60,7 @@ $SPARK_HOME/bin/spark-submit \
     --jars $COMET_JAR \
     --conf spark.driver.extraClassPath=$COMET_JAR \
     --conf spark.executor.extraClassPath=$COMET_JAR \
-    org.apache.spark.CometPlugin \
+    --conf spark.plugins=org.apache.spark.CometPlugin \
     --conf spark.comet.enabled=true \
     --conf spark.comet.exec.enabled=true \
     --conf spark.comet.exec.all.enabled=true \
diff --git a/docs/source/contributor-guide/debugging.md b/docs/source/contributor-guide/debugging.md
index 8aad68698..832e5de98 100644
--- a/docs/source/contributor-guide/debugging.md
+++ b/docs/source/contributor-guide/debugging.md
@@ -130,7 +130,7 @@ Then build the Comet as [described](https://github.com/apache/arrow-datafusion-c
 Start Comet with `RUST_BACKTRACE=1`
 
 ```console
-RUST_BACKTRACE=1 $SPARK_HOME/spark-shell --jars spark/target/comet-spark-spark3.4_2.12-0.1.0-SNAPSHOT.jar org.apache.spark.CometPlugin --conf spark.comet.enabled=true --conf spark.comet.exec.enabled=true --conf spark.comet.exec.all.enabled=true
+RUST_BACKTRACE=1 $SPARK_HOME/spark-shell --jars spark/target/comet-spark-spark3.4_2.12-0.1.0-SNAPSHOT.jar --conf spark.plugins=org.apache.spark.CometPlugin --conf spark.comet.enabled=true --conf spark.comet.exec.enabled=true --conf spark.comet.exec.all.enabled=true
 ```
 
 Get the expanded exception details
diff --git a/fuzz-testing/README.md b/fuzz-testing/README.md
index 25589385c..aa312d0b7 100644
--- a/fuzz-testing/README.md
+++ b/fuzz-testing/README.md
@@ -86,7 +86,7 @@ Note that the output filename is currently hard-coded as `queries.sql`
 ```shell
 $SPARK_HOME/bin/spark-submit \
     --master $SPARK_MASTER \
-    org.apache.spark.CometPlugin \
+    --conf spark.plugins=org.apache.spark.CometPlugin \
     --conf spark.comet.enabled=true \
     --conf spark.comet.exec.enabled=true \
     --conf spark.comet.exec.all.enabled=true \

From 1a574e6bbfbd57aaa23d54aa7394efbf999ff8dc Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Wed, 21 Aug 2024 07:51:38 -0600
Subject: [PATCH 4/6] add unit test and fix bug

---
 .../main/scala/org/apache/spark/Plugins.scala | 32 ++++++++++---------
 .../org/apache/spark/CometPluginsSuite.scala  | 31 ++++++++++++++++++
 2 files changed, 48 insertions(+), 15 deletions(-)

diff --git a/spark/src/main/scala/org/apache/spark/Plugins.scala b/spark/src/main/scala/org/apache/spark/Plugins.scala
index eb3602354..4f33e3aa6 100644
--- a/spark/src/main/scala/org/apache/spark/Plugins.scala
+++ b/spark/src/main/scala/org/apache/spark/Plugins.scala
@@ -21,13 +21,11 @@ package org.apache.spark
 
 import java.{util => ju}
 import java.util.Collections
-
 import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}
 import org.apache.spark.comet.shims.ShimCometDriverPlugin
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.{EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD}
-import org.apache.spark.sql.internal.StaticSQLConf
-
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
 import org.apache.comet.{CometConf, CometSparkSessionExtensions}
 
 /**
@@ -46,18 +44,7 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
     logInfo("CometDriverPlugin init")
 
     // register CometSparkSessionExtensions if it isn't already registered
-    val extensionKey = StaticSQLConf.SPARK_SESSION_EXTENSIONS.key
-    val extensionClass = classOf[CometSparkSessionExtensions].getName
-    if (sc.conf.contains(extensionKey)) {
-      val extensions = sc.conf.get(extensionKey)
-      if (!extensions.split(",").map(_.trim).contains(extensionClass)) {
-        sc.conf.set(extensionKey, s"$extensions,$extensionClass")
-      } else {
-        sc.conf.set(extensionKey, extensionClass)
-      }
-    } else {
-      sc.conf.set(extensionKey, extensionClass)
-    }
+    CometDriverPlugin.registerCometSessionExtension(sc.conf)
 
     if (shouldOverrideMemoryConf(sc.getConf)) {
       val execMemOverhead = if (sc.getConf.contains(EXECUTOR_MEMORY_OVERHEAD.key)) {
         sc.getConf.getSizeAsMb(EXECUTOR_MEMORY_OVERHEAD.key)
@@ -109,6 +96,21 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
   }
 }
 
+object CometDriverPlugin {
+  def registerCometSessionExtension(conf: SparkConf): Unit = {
+    val extensionKey = StaticSQLConf.SPARK_SESSION_EXTENSIONS.key
+    val extensionClass = classOf[CometSparkSessionExtensions].getName
+    if (conf.contains(extensionKey)) {
+      val extensions = conf.get(extensionKey)
+      if (!extensions.split(",").map(_.trim).contains(extensionClass)) {
+        conf.set(extensionKey, s"$extensions,$extensionClass")
+      }
+    } else {
+      conf.set(extensionKey, extensionClass)
+    }
+  }
+}
+
 /**
  * The Comet plugin for Spark. To enable this plugin, set the config "spark.plugins" to
  * `org.apache.spark.CometPlugin`
diff --git a/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala b/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala
index 2263b2f99..60eeb7065 100644
--- a/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala
@@ -20,6 +20,7 @@
 package org.apache.spark
 
 import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.internal.StaticSQLConf
 
 class CometPluginsSuite extends CometTestBase {
   override protected def sparkConf: SparkConf = {
@@ -33,6 +34,36 @@ class CometPluginsSuite extends CometTestBase {
     conf
   }
 
+  test("Register Comet extension") {
+    // test common case where no extensions are previously registered
+    {
+      val conf = new SparkConf()
+      CometDriverPlugin.registerCometSessionExtension(conf)
+      assert("org.apache.comet.CometSparkSessionExtensions" == conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key))
+    }
+    // test case where Comet is already registered
+    {
+      val conf = new SparkConf()
+      conf.set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, "org.apache.comet.CometSparkSessionExtensions")
+      CometDriverPlugin.registerCometSessionExtension(conf)
+      assert("org.apache.comet.CometSparkSessionExtensions" == conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key))
+    }
+    // test case where other extensions are already registered
+    {
+      val conf = new SparkConf()
+      conf.set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, "foo,bar")
+      CometDriverPlugin.registerCometSessionExtension(conf)
+      assert("foo,bar,org.apache.comet.CometSparkSessionExtensions" == conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key))
+    }
+    // test case where other extensions, including Comet, are already registered
+    {
+      val conf = new SparkConf()
+      conf.set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, "foo,bar,org.apache.comet.CometSparkSessionExtensions")
+      CometDriverPlugin.registerCometSessionExtension(conf)
+      assert("foo,bar,org.apache.comet.CometSparkSessionExtensions" == conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key))
+    }
+  }
+
   test("Default Comet memory overhead") {
     val execMemOverhead1 = spark.conf.get("spark.executor.memoryOverhead")
     val execMemOverhead2 = spark.sessionState.conf.getConfString("spark.executor.memoryOverhead")
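Moving the logic into a companion object is what makes the new unit test possible: the behavior can now be checked against a bare `SparkConf`, with no `SparkContext` or plugin lifecycle involved. A minimal usage sketch along the lines of the test above (`foo` and `bar` stand in for arbitrary third-party extension classes, as in the test):

```scala
import org.apache.spark.{CometDriverPlugin, SparkConf}

// Exercise the extracted helper directly against a bare SparkConf.
val conf = new SparkConf().set("spark.sql.extensions", "foo,bar")
CometDriverPlugin.registerCometSessionExtension(conf)
// Comet is appended; the pre-existing extensions are preserved.
assert(conf.get("spark.sql.extensions") ==
  "foo,bar,org.apache.comet.CometSparkSessionExtensions")
```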
s"$extensions,$extensionClass") } - } else { - conf.set(extensionKey, extensionClass) } } } From 6d87b6a5af257906d9024bf0ae91e1db396a7936 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 22 Aug 2024 12:27:11 -0600 Subject: [PATCH 6/6] address feedback --- .../main/scala/org/apache/spark/Plugins.scala | 11 ++++++--- .../org/apache/spark/CometPluginsSuite.scala | 24 ++++++++++++++----- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/Plugins.scala b/spark/src/main/scala/org/apache/spark/Plugins.scala index 5f2c9515b..ea5ef1663 100644 --- a/spark/src/main/scala/org/apache/spark/Plugins.scala +++ b/spark/src/main/scala/org/apache/spark/Plugins.scala @@ -21,11 +21,13 @@ package org.apache.spark import java.{util => ju} import java.util.Collections + import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin} import org.apache.spark.comet.shims.ShimCometDriverPlugin import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} -import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.internal.StaticSQLConf + import org.apache.comet.{CometConf, CometSparkSessionExtensions} /** @@ -96,17 +98,20 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl } } -object CometDriverPlugin { +object CometDriverPlugin extends Logging { def registerCometSessionExtension(conf: SparkConf): Unit = { val extensionKey = StaticSQLConf.SPARK_SESSION_EXTENSIONS.key val extensionClass = classOf[CometSparkSessionExtensions].getName val extensions = conf.get(extensionKey, "") if (extensions.isEmpty) { + logInfo(s"Setting $extensionKey=$extensionClass") conf.set(extensionKey, extensionClass) } else { val currentExtensions = extensions.split(",").map(_.trim) if (!currentExtensions.contains(extensionClass)) { - conf.set(extensionKey, s"$extensions,$extensionClass") + val newValue = s"$extensions,$extensionClass" + logInfo(s"Setting $extensionKey=$newValue") + conf.set(extensionKey, newValue) } } } diff --git a/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala b/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala index 60eeb7065..d74142343 100644 --- a/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala @@ -39,28 +39,40 @@ class CometPluginsSuite extends CometTestBase { { val conf = new SparkConf() CometDriverPlugin.registerCometSessionExtension(conf) - assert("org.apache.comet.CometSparkSessionExtensions" == conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key)) + assert( + "org.apache.comet.CometSparkSessionExtensions" == conf.get( + StaticSQLConf.SPARK_SESSION_EXTENSIONS.key)) } // test case where Comet is already registered { val conf = new SparkConf() - conf.set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, "org.apache.comet.CometSparkSessionExtensions") + conf.set( + StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, + "org.apache.comet.CometSparkSessionExtensions") CometDriverPlugin.registerCometSessionExtension(conf) - assert("org.apache.comet.CometSparkSessionExtensions" == conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key)) + assert( + "org.apache.comet.CometSparkSessionExtensions" == conf.get( + StaticSQLConf.SPARK_SESSION_EXTENSIONS.key)) } // test case where other extensions are already registered { val conf = new SparkConf() conf.set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, "foo,bar") 
       CometDriverPlugin.registerCometSessionExtension(conf)
-      assert("foo,bar,org.apache.comet.CometSparkSessionExtensions" == conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key))
+      assert(
+        "foo,bar,org.apache.comet.CometSparkSessionExtensions" == conf.get(
+          StaticSQLConf.SPARK_SESSION_EXTENSIONS.key))
     }
     // test case where other extensions, including Comet, are already registered
     {
       val conf = new SparkConf()
-      conf.set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, "foo,bar,org.apache.comet.CometSparkSessionExtensions")
+      conf.set(
+        StaticSQLConf.SPARK_SESSION_EXTENSIONS.key,
+        "foo,bar,org.apache.comet.CometSparkSessionExtensions")
       CometDriverPlugin.registerCometSessionExtension(conf)
-      assert("foo,bar,org.apache.comet.CometSparkSessionExtensions" == conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key))
+      assert(
+        "foo,bar,org.apache.comet.CometSparkSessionExtensions" == conf.get(
+          StaticSQLConf.SPARK_SESSION_EXTENSIONS.key))
     }
   }
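
After patch 6, the helper is idempotent, preserves user-supplied extensions, and logs each change it makes. A hedged sketch of the end-state behavior (`com.example.MyExtensions` is a hypothetical third-party extension; the idempotence follows from the contains check above):

```scala
import org.apache.spark.{CometDriverPlugin, SparkConf}

// Sketch: repeated registration (e.g. plugin re-initialization) does not
// duplicate the Comet entry, and existing extensions are kept.
val conf = new SparkConf().set("spark.sql.extensions", "com.example.MyExtensions")
CometDriverPlugin.registerCometSessionExtension(conf)
CometDriverPlugin.registerCometSessionExtension(conf)
assert(conf.get("spark.sql.extensions") ==
  "com.example.MyExtensions,org.apache.comet.CometSparkSessionExtensions")
```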