refactor: Remove a few duplicated occurrences (#53)
sunchao authored Feb 20, 2024
1 parent 2820327 commit 7e206e2
Showing 2 changed files with 41 additions and 52 deletions.
spark/src/main/scala/org/apache/comet/CometExecIterator.scala (49 changes: 33 additions & 16 deletions)

@@ -19,12 +19,11 @@
 
 package org.apache.comet
 
-import java.util.HashMap
-
 import org.apache.spark._
 import org.apache.spark.sql.comet.CometMetricNode
 import org.apache.spark.sql.vectorized._
 
+import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION}
 import org.apache.comet.vector.NativeUtil
 
 /**
@@ -45,36 +44,31 @@ class CometExecIterator(
     val id: Long,
     inputs: Seq[Iterator[ColumnarBatch]],
     protobufQueryPlan: Array[Byte],
-    configs: HashMap[String, String],
     nativeMetrics: CometMetricNode)
     extends Iterator[ColumnarBatch] {
 
   private val nativeLib = new Native()
-  private val plan = nativeLib.createPlan(id, configs, protobufQueryPlan, nativeMetrics)
+  private val plan = {
+    val configs = createNativeConf
+    nativeLib.createPlan(id, configs, protobufQueryPlan, nativeMetrics)
+  }
   private val nativeUtil = new NativeUtil
   private var nextBatch: Option[ColumnarBatch] = None
   private var currentBatch: ColumnarBatch = null
   private var closed: Boolean = false
 
   private def peekNext(): ExecutionState = {
-    val result = nativeLib.peekNext(plan)
-    val flag = result(0)
-
-    if (flag == 0) Pending
-    else if (flag == 1) {
-      val numRows = result(1)
-      val addresses = result.slice(2, result.length)
-      Batch(numRows = numRows.toInt, addresses = addresses)
-    } else {
-      throw new IllegalStateException(s"Invalid native flag: $flag")
-    }
+    convertNativeResult(nativeLib.peekNext(plan))
   }
 
   private def executeNative(
       input: Array[Array[Long]],
       finishes: Array[Boolean],
       numRows: Int): ExecutionState = {
-    val result = nativeLib.executePlan(plan, input, finishes, numRows)
+    convertNativeResult(nativeLib.executePlan(plan, input, finishes, numRows))
+  }
+
+  private def convertNativeResult(result: Array[Long]): ExecutionState = {
     val flag = result(0)
     if (flag == -1) EOF
     else if (flag == 0) Pending
@@ -87,6 +81,29 @@ class CometExecIterator(
     }
   }
 
+  /**
+   * Creates a new configuration map to be passed to the native side.
+   */
+  private def createNativeConf: java.util.HashMap[String, String] = {
+    val result = new java.util.HashMap[String, String]()
+    val conf = SparkEnv.get.conf
+
+    val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(conf)
+    result.put("memory_limit", String.valueOf(maxMemory))
+    result.put("memory_fraction", String.valueOf(COMET_EXEC_MEMORY_FRACTION.get()))
+    result.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
+    result.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get()))
+
+    // Strip mandatory prefix spark. which is not required for DataFusion session params
+    conf.getAll.foreach {
+      case (k, v) if k.startsWith("spark.datafusion") =>
+        result.put(k.replaceFirst("spark\\.", ""), v)
+      case _ =>
+    }
+
+    result
+  }
+
   /** Execution result from Comet native */
   trait ExecutionState
 
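A note on the helper factored out above: convertNativeResult centralizes the decoding of the Array[Long] that the native library returns from both peekNext and executePlan. The sketch below mirrors only what this diff itself shows (flag -1 for EOF, 0 for Pending, 1 for a Batch whose row count sits in slot 1 and whose Arrow array addresses follow); the ExecutionState types are redeclared locally and the sample addresses are made-up values, so the snippet compiles and runs on its own:

    object NativeResultSketch extends App {
      sealed trait ExecutionState
      case object EOF extends ExecutionState
      case object Pending extends ExecutionState
      case class Batch(numRows: Int, addresses: Array[Long]) extends ExecutionState

      // Slot 0 is the status flag; for a batch, slot 1 holds the row count
      // and the remaining slots hold Arrow array addresses.
      def convertNativeResult(result: Array[Long]): ExecutionState = {
        val flag = result(0)
        if (flag == -1) EOF
        else if (flag == 0) Pending
        else if (flag == 1) Batch(result(1).toInt, result.slice(2, result.length))
        else throw new IllegalStateException(s"Invalid native flag: $flag")
      }

      // Example: flag 1 (batch) with 3 rows and two array addresses.
      convertNativeResult(Array(1L, 3L, 100L, 200L)) match {
        case Batch(numRows, addresses) =>
          assert(numRows == 3 && addresses.sameElements(Array(100L, 200L)))
        case other => throw new AssertionError(s"Unexpected state: $other")
      }
    }
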
spark/src/main/scala/org/apache/spark/sql/comet/operators.scala (44 changes: 8 additions & 36 deletions)

@@ -23,20 +23,20 @@ import java.io.ByteArrayOutputStream
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import com.google.common.base.Objects
 
-import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException, CometSparkSessionExtensions}
-import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION}
+import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 
 /**
@@ -83,17 +83,7 @@ object CometExec {
     nativePlan.writeTo(outputStream)
     outputStream.close()
     val bytes = outputStream.toByteArray
-
-    val configs = new java.util.HashMap[String, String]()
-
-    val maxMemory =
-      CometSparkSessionExtensions.getCometMemoryOverhead(SparkEnv.get.conf)
-    configs.put("memory_limit", String.valueOf(maxMemory))
-    configs.put("memory_fraction", String.valueOf(COMET_EXEC_MEMORY_FRACTION.get()))
-    configs.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
-    configs.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get()))
-
-    new CometExecIterator(newIterId, inputs, bytes, configs, nativeMetrics)
+    new CometExecIterator(newIterId, inputs, bytes, nativeMetrics)
   }
 }

@@ -163,33 +153,15 @@ abstract class CometNativeExec extends CometExec {
       case Some(serializedPlan) =>
         // Switch to use Decimal128 regardless of precision, since Arrow native execution
         // doesn't support Decimal32 and Decimal64 yet.
-        conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
-
-        // Populate native configurations
-        val configs = new java.util.HashMap[String, String]()
-        val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(sparkContext.getConf)
-        configs.put("memory_limit", String.valueOf(maxMemory))
-        configs.put("memory_fraction", String.valueOf(COMET_EXEC_MEMORY_FRACTION.get()))
-        configs.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
-        configs.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get()))
-
-        // Strip mandatory prefix spark. which is not required for datafusion session params
-        session.conf.getAll.foreach {
-          case (k, v) if k.startsWith("spark.datafusion") =>
-            configs.put(k.replaceFirst("spark\\.", ""), v)
-          case _ =>
-        }
+        SQLConf.get.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
 
         val serializedPlanCopy = serializedPlan
         // TODO: support native metrics for all operators.
         val nativeMetrics = CometMetricNode.fromCometPlan(this)
 
         def createCometExecIter(inputs: Seq[Iterator[ColumnarBatch]]): CometExecIterator = {
-          val it = new CometExecIterator(
-            CometExec.newIterId,
-            inputs,
-            serializedPlanCopy,
-            configs,
-            nativeMetrics)
+          val it =
+            new CometExecIterator(CometExec.newIterId, inputs, serializedPlanCopy, nativeMetrics)
 
           setSubqueries(it.id, originalPlan)
 
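The config-building blocks deleted from this file now live behind createNativeConf in CometExecIterator. The subtlest part of that shared logic is the key rewrite for DataFusion session options, sketched below as a runnable snippet. The keys and values are illustrative examples, not settings taken from this commit; note that replaceFirst rewrites the first match anywhere in the string, which is safe here only because the guard guarantees the key starts with "spark.datafusion":

    object NativeConfKeySketch extends App {
      val sparkConf = Seq(
        "spark.datafusion.execution.batch_size" -> "8192", // forwarded, prefix stripped
        "spark.executor.memory" -> "4g") // not a DataFusion option, dropped

      // Same shape as the foreach in createNativeConf, collected into a Map.
      val nativeConf = sparkConf.collect {
        case (k, v) if k.startsWith("spark.datafusion") =>
          k.replaceFirst("spark\\.", "") -> v
      }.toMap

      assert(nativeConf == Map("datafusion.execution.batch_size" -> "8192"))
    }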
