[SPARK-50372][CONNECT][SQL] Make all DF execution path collect observed metrics #48920

Status: Open · wants to merge 4 commits into master
File: SparkSession.scala (Spark Connect client)
@@ -34,6 +34,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.connect.proto.ExecutePlanResponse.ObservedMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalog.Catalog
@@ -385,13 +386,8 @@ class SparkSession private[sql] (
private[sql] def timeZoneId: String = conf.get(SqlApiConf.SESSION_LOCAL_TIMEZONE_KEY)

private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = {
val value = client.execute(plan)
new SparkResult(
value,
allocator,
encoder,
timeZoneId,
Some(setMetricsAndUnregisterObservation))
val value = executeInternal(plan)
new SparkResult(value, allocator, encoder, timeZoneId)
}

private[sql] def execute(f: proto.Relation.Builder => Unit): Unit = {
@@ -400,7 +396,7 @@ class SparkSession private[sql] (
builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
val plan = proto.Plan.newBuilder().setRoot(builder).build()
// .foreach forces that the iterator is consumed and closed
client.execute(plan).foreach(_ => ())
executeInternal(plan).foreach(_ => ())
}

@Since("4.0.0")
@@ -409,11 +405,26 @@ class SparkSession private[sql] (
val plan = proto.Plan.newBuilder().setCommand(command).build()
// .toSeq forces that the iterator is consumed and closed. On top, ignore all
// progress messages.
client.execute(plan).filter(!_.hasExecutionProgress).toSeq
executeInternal(plan).filter(!_.hasExecutionProgress).toSeq
}

private[sql] def execute(plan: proto.Plan): CloseableIterator[ExecutePlanResponse] =
client.execute(plan)
/**
* The real `execute` method that calls into `SparkConnectClient`.
*
* Here we inject a lazy map that processes registered observed metrics, so consumers of the
* returned iterator do not need to worry about it.
*
* Please make sure all `execute` methods call this method.
*/
private[sql] def executeInternal(plan: proto.Plan): CloseableIterator[ExecutePlanResponse] = {
client
.execute(plan)
.map { response =>
// Note, this map() is lazy.
processRegisteredObservedMetrics(response.getObservedMetricsList)
response
}
}
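A minimal, standalone sketch (plain Scala iterators and hypothetical values, not the Spark Connect API) of the lazy-map behaviour described above: the metric-processing side effect runs only as the consumer pulls responses, not when executeInternal returns.

// Standalone sketch: the side effect attached via map() runs lazily,
// element by element, as the iterator is consumed.
val responses = Iterator("response-1", "response-2")
val wrapped = responses.map { r =>
  println(s"processing observed metrics of $r") // stands in for processRegisteredObservedMetrics
  r
}
// Nothing has been printed yet; consuming the iterator triggers the processing.
wrapped.foreach(_ => ())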

private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = {
val command = proto.Command.newBuilder().setRegisterFunction(udf).build()
@@ -555,10 +566,14 @@ class SparkSession private[sql] (
observationRegistry.putIfAbsent(planId, observation)
}

private[sql] def setMetricsAndUnregisterObservation(planId: Long, metrics: Row): Unit = {
val observationOrNull = observationRegistry.remove(planId)
if (observationOrNull != null) {
observationOrNull.setMetricsAndNotify(metrics)
private def processRegisteredObservedMetrics(metrics: java.util.List[ObservedMetrics]): Unit = {
metrics.asScala.foreach { metric =>
// Here we only process metrics that belong to a registered Observation object.
// All metrics, whether registered or not, will be collected by `SparkResult`.
val observationOrNull = observationRegistry.remove(metric.getPlanId)
if (observationOrNull != null) {
observationOrNull.setMetricsAndNotify(SparkResult.transformObservedMetrics(metric))
}
}
}

File: ClientE2ETestSuite.scala
@@ -1536,28 +1536,49 @@ class ClientE2ETestSuite
val ob1Metrics = Map("ob1" -> new GenericRowWithSchema(Array(0, 49, 98), ob1Schema))
val ob2Metrics = Map("ob2" -> new GenericRowWithSchema(Array(-1, 48, 97), ob2Schema))

val obMetrics = observedDf.collectResult().getObservedMetrics
assert(df.collectResult().getObservedMetrics === Map.empty)
assert(observedDf.collectResult().getObservedMetrics === ob1Metrics)
assert(observedObservedDf.collectResult().getObservedMetrics === ob1Metrics ++ ob2Metrics)
}

test("Observation.get is blocked until the query is finished") {
val df = spark.range(99).withColumn("extra", col("id") - 1)
val observation = new Observation("ob1")
val observedDf = df.observe(observation, min("id"), avg("id"), max("id"))

// Start a new thread to get the observation
val future = Future(observation.get)(ExecutionContext.global)
// make sure the thread is blocked right now
val e = intercept[java.util.concurrent.TimeoutException] {
SparkThreadUtils.awaitResult(future, 2.seconds)
assert(obMetrics.map(_._2.schema) === Seq(ob1Schema))

val obObMetrics = observedObservedDf.collectResult().getObservedMetrics
assert(obObMetrics === ob1Metrics ++ ob2Metrics)
assert(obObMetrics.map(_._2.schema).exists(_.equals(ob1Schema)))
assert(obObMetrics.map(_._2.schema).exists(_.equals(ob2Schema)))
}

for (collectFunc <- Seq(
("collect", (df: DataFrame) => df.collect()),
("collectAsList", (df: DataFrame) => df.collectAsList()),
("collectResult", (df: DataFrame) => df.collectResult().length),
("write", (df: DataFrame) => df.write.format("noop").mode("append").save())))
test(
"Observation.get is blocked until the query is finished, " +
s"collect using method ${collectFunc._1}") {
val df = spark.range(99).withColumn("extra", col("id") - 1)
val ob1 = new Observation("ob1")
val ob2 = new Observation("ob2")
val observedDf = df.observe(ob1, min("id"), avg("id"), max("id"))
val observedObservedDf = observedDf.observe(ob2, min("extra"), avg("extra"), max("extra"))
// Start new threads to get observations
val future1 = Future(ob1.get)(ExecutionContext.global)
val future2 = Future(ob2.get)(ExecutionContext.global)
// make sure the threads are blocked right now
val e1 = intercept[java.util.concurrent.TimeoutException] {
SparkThreadUtils.awaitResult(future1, 2.seconds)
}
assert(e1.getMessage.contains("timed out after"))
val e2 = intercept[java.util.concurrent.TimeoutException] {
SparkThreadUtils.awaitResult(future2, 2.seconds)
}
assert(e2.getMessage.contains("timed out after"))
collectFunc._2(observedObservedDf)
// make sure the threads are unblocked after the query is finished
val metrics1 = SparkThreadUtils.awaitResult(future1, 5.seconds)
assert(metrics1 === Map("min(id)" -> 0, "avg(id)" -> 49, "max(id)" -> 98))
val metrics2 = SparkThreadUtils.awaitResult(future2, 5.seconds)
assert(metrics2 === Map("min(extra)" -> -1, "avg(extra)" -> 48, "max(extra)" -> 97))
}
assert(e.getMessage.contains("Future timed out"))
observedDf.collect()
// make sure the thread is unblocked after the query is finished
val metrics = SparkThreadUtils.awaitResult(future, 2.seconds)
assert(metrics === Map("min(id)" -> 0, "avg(id)" -> 49, "max(id)" -> 98))
}

test("SPARK-48852: trim function on a string column returns correct results") {
val session: SparkSession = spark
File: CloseableIterator.scala
@@ -25,6 +25,16 @@ private[sql] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable {

override def close() = self.close()
}

override def map[B](f: E => B): CloseableIterator[B] = {
new CloseableIterator[B] {
override def next(): B = f(self.next())

override def hasNext: Boolean = self.hasNext

override def close(): Unit = self.close()
}
}
}
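A small usage sketch (dummy source, illustrative only) of why map is overridden to return a CloseableIterator rather than a plain Iterator: closing the mapped view must still close the underlying stream, which is what lets executeInternal wrap responses without leaking them.

var closed = false
val source = new CloseableIterator[Int] {
  private val it = Iterator(1, 2, 3)
  override def hasNext: Boolean = it.hasNext
  override def next(): Int = it.next()
  override def close(): Unit = closed = true
}
val doubled = source.map(_ * 2) // still closeable
doubled.close()                 // delegates to source.close()
assert(closed)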

private[sql] abstract class WrappedCloseableIterator[E] extends CloseableIterator[E] {
File: SparkResult.scala
@@ -41,8 +41,7 @@ private[sql] class SparkResult[T](
responses: CloseableIterator[proto.ExecutePlanResponse],
allocator: BufferAllocator,
encoder: AgnosticEncoder[T],
timeZoneId: String,
setObservationMetricsOpt: Option[(Long, Row) => Unit] = None)
timeZoneId: String)
extends AutoCloseable { self =>

case class StageInfo(
@@ -122,7 +121,8 @@ private[sql] class SparkResult[T](
while (!stop && responses.hasNext) {
val response = responses.next()

// Collect metrics for this response
// Collect **all** metrics for this response, whether or not registered to an Observation
// object.
observedMetrics ++= processObservedMetrics(response.getObservedMetricsList)

// Save and validate operationId
@@ -209,23 +209,7 @@ private[sql] class SparkResult[T](
private def processObservedMetrics(
metrics: java.util.List[ObservedMetrics]): Iterable[(String, Row)] = {
metrics.asScala.map { metric =>
assert(metric.getKeysCount == metric.getValuesCount)
var schema = new StructType()
val values = mutable.ArrayBuilder.make[Any]
values.sizeHint(metric.getKeysCount)
(0 until metric.getKeysCount).foreach { i =>
val key = metric.getKeys(i)
val value = LiteralValueProtoConverter.toCatalystValue(metric.getValues(i))
schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
values += value
}
val row = new GenericRowWithSchema(values.result(), schema)
// If the metrics is registered by an Observation object, attach them and unblock any
// blocked thread.
setObservationMetricsOpt.foreach { setObservationMetrics =>
setObservationMetrics(metric.getPlanId, row)
}
metric.getName -> row
metric.getName -> SparkResult.transformObservedMetrics(metric)
}
}

@@ -387,8 +371,23 @@ private[sql] class SparkResult[T](
}
}

private object SparkResult {
private[sql] object SparkResult {
private val cleaner: Cleaner = Cleaner.create()

/** Transforms an ObservedMetrics message into a Row, preserving the order of its keys and values. */
private[sql] def transformObservedMetrics(metric: ObservedMetrics): Row = {
assert(metric.getKeysCount == metric.getValuesCount)
var schema = new StructType()
val values = mutable.ArrayBuilder.make[Any]
values.sizeHint(metric.getKeysCount)
(0 until metric.getKeysCount).foreach { i =>
val key = metric.getKeys(i)
val value = LiteralValueProtoConverter.toCatalystValue(metric.getValues(i))
schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
values += value
}
new GenericRowWithSchema(values.result(), schema)
}
}
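As an illustration of transformObservedMetrics (hypothetical metric values mirroring the loop above, not the proto API itself): each key becomes a schema field and the values end up in a single GenericRowWithSchema.

// Illustrative sketch only: keys drive the schema, values fill the row.
val keys = Seq("min(id)", "avg(id)", "max(id)")
val values = Seq[Any](0L, 49.0, 98L)
var schema = new StructType()
keys.zip(values).foreach { case (k, v) =>
  schema = schema.add(k, LiteralValueProtoConverter.toDataType(v.getClass))
}
val row = new GenericRowWithSchema(values.toArray, schema)
// row.schema.fieldNames: Array("min(id)", "avg(id)", "max(id)")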

private[client] class SparkResultCloseable(
File: ExecuteThreadRunner.scala
@@ -245,7 +245,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
.createObservedMetricsResponse(
executeHolder.sessionHolder.sessionId,
executeHolder.sessionHolder.serverSessionId,
executeHolder.request.getPlan.getRoot.getCommon.getPlanId,
executeHolder.allObservationAndPlanIds,
observedMetrics ++ accumulatedInPython))
}

File: SparkConnectPlanExecution.scala
@@ -77,8 +77,10 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
processAsArrowBatches(dataframe, responseObserver, executeHolder)
responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
createObservedMetricsResponse(request.getSessionId, dataframe).foreach(
responseObserver.onNext)
createObservedMetricsResponse(
request.getSessionId,
executeHolder.allObservationAndPlanIds,
dataframe).foreach(responseObserver.onNext)
}

type Batch = (Array[Byte], Long)
@@ -255,6 +257,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)

private def createObservedMetricsResponse(
sessionId: String,
observationAndPlanIds: Map[String, Long],
dataframe: DataFrame): Option[ExecutePlanResponse] = {
val observedMetrics = dataframe.queryExecution.observedMetrics.collect {
case (name, row) if !executeHolder.observations.contains(name) =>
@@ -264,13 +267,12 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
name -> values
}
if (observedMetrics.nonEmpty) {
val planId = executeHolder.request.getPlan.getRoot.getCommon.getPlanId
Some(
SparkConnectPlanExecution
.createObservedMetricsResponse(
sessionId,
sessionHolder.serverSessionId,
planId,
observationAndPlanIds,
observedMetrics))
} else None
}
@@ -280,17 +282,17 @@ object SparkConnectPlanExecution {
def createObservedMetricsResponse(
sessionId: String,
serverSessionId: String,
planId: Long,
observationAndPlanIds: Map[String, Long],
metrics: Map[String, Seq[(Option[String], Any)]]): ExecutePlanResponse = {
val observedMetrics = metrics.map { case (name, values) =>
val metrics = ExecutePlanResponse.ObservedMetrics
.newBuilder()
.setName(name)
.setPlanId(planId)
values.foreach { case (key, value) =>
metrics.addValues(toLiteralProto(value))
key.foreach(metrics.addKeys)
}
observationAndPlanIds.get(name).foreach(metrics.setPlanId)
metrics.build()
}
// Prepare a response with the observed metrics.
File: SparkConnectPlanner.scala
@@ -1190,14 +1190,14 @@ class SparkConnectPlanner(
val input = transformRelation(rel.getInput)

if (input.isStreaming || executeHolderOpt.isEmpty) {
CollectMetrics(name, metrics.map(_.named), transformRelation(rel.getInput), planId)
CollectMetrics(name, metrics.map(_.named), input, planId)
Review comment from @xupefei (Contributor, Author), Nov 21, 2024:
This fixes a bug where the input of a CollectMetrics could be processed twice: once at line 1190 and once here/below.

When the input contains another CollectMetrics, transforming it twice creates two Observation objects for the inner CollectMetrics and registers both with the system. Only one of them is fulfilled when the query finishes, so the one we end up holding may never receive any data.

This issue is exercised by the test case "Observation.get is blocked until the query is finished ...", where we specifically execute observedObservedDf, a CollectMetrics whose input is another CollectMetrics.
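A minimal sketch of the shape that triggers the issue, reusing the names from the test above (illustrative only, no new API):

val ob1 = new Observation("ob1")
val ob2 = new Observation("ob2")
val observedDf = df.observe(ob1, min("id"), avg("id"), max("id"))
// observedObservedDf is a CollectMetrics whose input is another CollectMetrics;
// transforming rel.getInput twice registered ob1 twice, and only one of those
// registrations was ever fulfilled, so ob1.get could end up waiting forever.
val observedObservedDf = observedDf.observe(ob2, min("extra"), avg("extra"), max("extra"))
observedObservedDf.collect() // with the fix, ob1 and ob2 are each fulfilled exactly once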

} else {
// TODO this might be too complex for no good reason. It might
// be easier to inspect the plan after it completes.
val observation = Observation(name)
session.observationManager.register(observation, planId)
executeHolderOpt.get.addObservation(name, observation)
CollectMetrics(name, metrics.map(_.named), transformRelation(rel.getInput), planId)
CollectMetrics(name, metrics.map(_.named), input, planId)
}
}

File: ExecuteHolder.scala
@@ -22,6 +22,8 @@ import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable
import scala.jdk.CollectionConverters._

import com.google.protobuf.GeneratedMessage

import org.apache.spark.SparkEnv
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
@@ -81,6 +83,10 @@ private[connect] class ExecuteHolder(

val observations: mutable.Map[String, Observation] = mutable.Map.empty

lazy val allObservationAndPlanIds: Map[String, Long] = {
ExecuteHolder.collectAllObservationAndPlanIds(request.getPlan).toMap
}

private val runner: ExecuteThreadRunner = new ExecuteThreadRunner(this)

/** System.currentTimeMillis when this ExecuteHolder was created. */
@@ -289,6 +295,26 @@ private[connect] class ExecuteHolder(
def operationId: String = key.operationId
}

private object ExecuteHolder {
private def collectAllObservationAndPlanIds(
planOrMessage: GeneratedMessage,
collected: mutable.Map[String, Long] = mutable.Map.empty): mutable.Map[String, Long] = {
planOrMessage match {
case relation: proto.Relation if relation.hasCollectMetrics =>
collected += relation.getCollectMetrics.getName -> relation.getCommon.getPlanId
collectAllObservationAndPlanIds(relation.getCollectMetrics.getInput, collected)
case _ =>
planOrMessage.getAllFields.values().asScala.foreach {
case message: GeneratedMessage =>
collectAllObservationAndPlanIds(message, collected)
case _ =>
// not a message (probably a primitive type), do nothing
}
}
collected
}
}
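A hedged sketch of what the traversal collects for a plan with nested CollectMetrics nodes; the protobuf builder calls are assumed from the accessors used above (getCollectMetrics, getCommon.getPlanId) and standard generated-builder conventions.

val inner = proto.Relation.newBuilder()
  .setCommon(proto.RelationCommon.newBuilder().setPlanId(1L))
  .setCollectMetrics(proto.CollectMetrics.newBuilder().setName("ob1"))
  .build()
val outer = proto.Relation.newBuilder()
  .setCommon(proto.RelationCommon.newBuilder().setPlanId(2L))
  .setCollectMetrics(proto.CollectMetrics.newBuilder().setName("ob2").setInput(inner))
  .build()
// collectAllObservationAndPlanIds(proto.Plan.newBuilder().setRoot(outer).build())
// would yield Map("ob2" -> 2L, "ob1" -> 1L).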

/** Used to identify ExecuteHolder jobTag among SparkContext.SPARK_JOB_TAGS. */
object ExecuteJobTag {
private val prefix = "SparkConnect_OperationTag"