Merge pull request #14 from lsds/mapPipeline
Added support for operator pipelining
pgaref authored Jun 26, 2019
2 parents 805dad1 + fbf4a63 commit 8f51466
Showing 9 changed files with 193 additions and 45 deletions.
15 changes: 10 additions & 5 deletions core/src/main/scala/org/apache/spark/api/sgx/SGXBaseRunner.scala
@@ -29,11 +29,13 @@ import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils

import scala.reflect.{ClassTag}
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

private[spark] abstract class SGXBaseRunner[IN: ClassTag, OUT: ClassTag](
func: (Iterator[Any]) => Any,
evalType: Int) extends Logging {
evalType: Int,
funcs: ArrayBuffer[(Iterator[Any]) => Any]) extends Logging {

private val conf = SparkEnv.get.conf
private val bufferSize = conf.getInt("spark.buffer.size", 65536)
@@ -316,6 +318,9 @@ private[spark] abstract class SGXBaseRunner[IN: ClassTag, OUT: ClassTag](

private[spark] object SGXFunctionType {
val NON_UDF = 0
val PIPELINED = 1
val SHUFFLE_MAP = 2
val SHUFFLE_REDUCE = 3
val BATCHED_UDF = 100
def toString(sgxFuncType: Int): String = sgxFuncType match {
case NON_UDF => "NON_UDF"
@@ -324,11 +329,11 @@
}

private[spark] object SpecialSGXChars {
val EMPTY_DATA = 0
val END_OF_FUNC_SECTION = 0
val END_OF_DATA_SECTION = -1
val SGX_EXCEPTION_THROWN = -2
val TIMING_DATA = -3
val END_OF_STREAM = -4
val NULL = -5
val START_ARROW_STREAM = -6
val EMPTY_DATA = -5
val NULL = -6
}
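
The control words above frame every message exchanged with the SGX worker: a positive integer announces a length-prefixed payload, while END_OF_FUNC_SECTION (0) and the negative sentinels delimit sections of the stream. A minimal reader-side sketch of that framing, for illustration only and not part of this commit (the helper name is hypothetical):

import java.io.DataInputStream
import org.apache.spark.api.sgx.SpecialSGXChars

// Hypothetical helper: pull one framed item off the stream, treating positive ints as
// payload lengths and the SpecialSGXChars sentinels as section delimiters.
def readFramed(in: DataInputStream): Option[Array[Byte]] = in.readInt() match {
  case len if len > 0 =>
    val payload = new Array[Byte](len)
    in.readFully(payload)
    Some(payload)                       // a serialized closure or data record
  case SpecialSGXChars.END_OF_FUNC_SECTION |
       SpecialSGXChars.END_OF_DATA_SECTION |
       SpecialSGXChars.END_OF_STREAM =>
    None                                // the current section is finished
  case SpecialSGXChars.EMPTY_DATA | SpecialSGXChars.NULL =>
    Some(Array.emptyByteArray)          // placeholder for empty or null records
  case other =>
    throw new IllegalStateException(s"Unexpected control word: $other")
}
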
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/api/sgx/SGXRDD.scala
@@ -66,7 +66,7 @@ private[spark] object SGXRDD extends Logging {
case _ =>
val outSerArray = serializer.serialize(obj).array()
dataOut.writeInt(outSerArray.length)
logDebug(s"SGX => Writing: ${outSerArray}")
logDebug(s"SGX => Writing: ${obj}")
dataOut.write(outSerArray)
// throw new SparkException("Unexpected element type " + other.getClass)
}
25 changes: 22 additions & 3 deletions core/src/main/scala/org/apache/spark/api/sgx/SGXRunner.scala
@@ -22,15 +22,25 @@ import java.net.Socket

import org.apache.spark._

import scala.collection.mutable.ArrayBuffer

private[spark] object SGXRunner {
def apply(func: (Iterator[Any]) => Any): SGXRunner = {
new SGXRunner(func)
new SGXRunner(func, SGXFunctionType.NON_UDF, ArrayBuffer.empty)
}

def apply(func: (Iterator[Any]) => Any, funcType: Int): SGXRunner = {
new SGXRunner(func, funcType, ArrayBuffer.empty)
}

def apply(func: (Iterator[Any]) => Any, funcType: Int, funcs: ArrayBuffer[(Iterator[Any]) => Any]): SGXRunner = {
new SGXRunner(func, funcType, funcs)
}
}
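
For illustration, a hedged sketch of how the new overloads might be invoked (the closure names are hypothetical): the final closure of a pipelined stage is passed as func, and its upstream closures travel in funcs in execution order.

import scala.collection.mutable.ArrayBuffer

// Hypothetical Iterator[Any]-level closures for a two-step pipeline.
val mapStep: (Iterator[Any]) => Any    = it => it.map(x => x.toString)
val filterStep: (Iterator[Any]) => Any = it => it.filter(x => x != null)

// mapStep is serialized first and therefore runs first on the worker, then filterStep.
val runner = SGXRunner(filterStep, SGXFunctionType.PIPELINED, ArrayBuffer(mapStep))
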

/** Helper class to run a function in SGX Spark */
private[spark] class SGXRunner(func: (Iterator[Any]) => Any) extends
SGXBaseRunner[Array[Byte], Array[Byte]](func, SGXFunctionType.NON_UDF) {
private[spark] class SGXRunner(func: (Iterator[Any]) => Any, funcType: Int, funcs: ArrayBuffer[(Iterator[Any]) => Any])
extends SGXBaseRunner[Array[Byte], Array[Byte]](func, funcType, funcs) {

override protected def sgxWriterThread(env: SparkEnv,
worker: Socket,
@@ -39,9 +49,18 @@ private[spark] class SGXRunner(func: (Iterator[Any]) => Any) extends
new WriterIterator(env, worker, inputIterator, partitionIndex, context) {
/** Writes a command section to the stream connected to the SGX worker */
override protected def writeFunction(dataOut: DataOutputStream): Unit = {
logInfo(s"Ser ${funcs.size + 1} closures")
for (currFunc <- funcs) {
logDebug(s"Ser closure: ${currFunc.getClass}")
val command = closureSer.serialize(currFunc)
dataOut.writeInt(command.array().size)
dataOut.write(command.array())
}
val command = closureSer.serialize(func)
logDebug(s"Ser func: ${func.getClass}")
dataOut.writeInt(command.array().size)
dataOut.write(command.array())
dataOut.writeInt(SpecialSGXChars.END_OF_FUNC_SECTION)
dataOut.flush()
}

@@ -21,18 +21,18 @@ import java.io.{DataInputStream, DataOutputStream, IOException}
import java.net.{InetAddress, Socket}
import java.nio.ByteBuffer

import org.apache.spark.api.sgx.{SGXException, SGXRDD, SpecialSGXChars}
import org.apache.spark.api.sgx.{SGXException, SGXFunctionType, SGXRDD, SpecialSGXChars}
import org.apache.spark.internal.Logging
import org.apache.spark.serializer._
import org.apache.spark.util.Utils
import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext}
import org.apache.spark.{SparkConf, SparkException, TaskContext}

import scala.collection.mutable
import scala.reflect.ClassTag

private[spark] class SGXWorker(closuseSer: SerializerInstance, dataSer: SerializerInstance) extends Logging {
val SYSTEM_NAME = "sparkSGXWorker"
val ENDPOINT_NAME = "SGXWorker"
val SYSTEM_NAME = "SparkSGXWorker"
val ENDPOINT_NAME = "SecureWorker"
if (dataSer == null || closuseSer == null) {
throw new SGXException("Worker Serializer not set", new RuntimeException)
}
@@ -65,23 +65,38 @@ private[spark] class SGXWorker(closuseSer: SerializerInstance, dataSer: Serializ
val spark_files_dir = SGXRDD.readUTF(inSock)

// Read Function Type & Function
val eval_type = inSock.readInt()
val func = readFunction(inSock)

val init_time = System.nanoTime()
val eval_type = inSock.readInt()

// Read Iterator
val iterator = new ReaderIterator(inSock, dataSer)
val res = func(iterator)

SGXRDD.writeIteratorToStream[Any](res.asInstanceOf[Iterator[Any]], dataSer, outSock)
outSock.writeInt(SpecialSGXChars.END_OF_DATA_SECTION)
outSock.flush()
val funcArray: mutable.ArrayBuffer[(Iterator[Any]) => Any] = readFunction(inSock)
logInfo(s"Executing ${funcArray.size} (pipelined) funcs")

eval_type match {
case SGXFunctionType.NON_UDF =>
// Read Iterator
val iterator = new ReaderIterator(inSock, dataSer)
val res = funcArray.head(iterator)
SGXRDD.writeIteratorToStream[Any](res.asInstanceOf[Iterator[Any]], dataSer, outSock)
outSock.writeInt(SpecialSGXChars.END_OF_DATA_SECTION)
outSock.flush()
case SGXFunctionType.PIPELINED =>
val iterator = new ReaderIterator(inSock, dataSer)

var res: Iterator[Any] = null
for (func <- funcArray) {
logDebug(s"Running Func ${func.getClass}")
res = if (res == null) func(iterator).asInstanceOf[Iterator[Any]] else func(res).asInstanceOf[Iterator[Any]]
}
SGXRDD.writeIteratorToStream[Any](res, dataSer, outSock)
outSock.writeInt(SpecialSGXChars.END_OF_DATA_SECTION)
outSock.flush()
case _ =>
logError(s"Unsupported FunctionType ${eval_type}")
}

val finishTime = System.nanoTime()

// Write reportTimes AND Shuffle timestamps

outSock.writeInt(SpecialSGXChars.END_OF_STREAM)
outSock.flush()
// send metrics etc
@@ -99,11 +114,22 @@ private[spark] class SGXWorker(closuseSer: SerializerInstance, dataSer: Serializ
outfile.writeLong(finishTime)
}

def readFunction(inSock: DataInputStream): (Iterator[Any]) => Any = {
val func_size = inSock.readInt()
val obj = new Array[Byte](func_size)
inSock.readFully(obj)
closuseSer.deserialize[(Iterator[Any]) => Any](ByteBuffer.wrap(obj))
def readFunction(inSock: DataInputStream): mutable.ArrayBuffer[(Iterator[Any]) => Any] = {
val functionArr = mutable.ArrayBuffer[(Iterator[Any]) => Any]()
var done = false
while (!done) {
inSock.readInt() match {
case func_size if func_size > 0 =>
val obj = new Array[Byte](func_size)
inSock.readFully(obj)
val closure = closuseSer.deserialize[(Iterator[Any]) => Any](ByteBuffer.wrap(obj))
functionArr.append(closure)
case SpecialSGXChars.END_OF_FUNC_SECTION =>
logDebug(s"Read ${functionArr.size} functions Done")
done = true
}
}
functionArr
}
}
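
The PIPELINED branch in process() is effectively a left fold of the deserialized closures over the partition iterator. A hedged equivalent formulation, assuming each closure yields an Iterator (illustration only, not part of this commit):

// Feed the partition iterator through each received closure, in the order the driver
// serialized them (upstream closures first, the stage's final closure last), e.g.
// runPipeline(funcArray, new ReaderIterator(inSock, dataSer)).
def runPipeline(funcs: Seq[(Iterator[Any]) => Any], input: Iterator[Any]): Iterator[Any] =
  funcs.foldLeft(input)((it, f) => f(it).asInstanceOf[Iterator[Any]])
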

23 changes: 21 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.SparkEnv
import org.apache.spark.{Partition, TaskContext}

/**
@@ -39,6 +40,7 @@ import org.apache.spark.{Partition, TaskContext}
private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
var prev: RDD[T],
f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
cleanFunc: (Iterator[Any]) => Any = null,
preservesPartitioning: Boolean = false,
isFromBarrier: Boolean = false,
isOrderSensitive: Boolean = false)
@@ -48,8 +50,25 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](

override def getPartitions: Array[Partition] = firstParent[T].partitions

override def compute(split: Partition, context: TaskContext): Iterator[U] =
f(context, split.index, firstParent[T].iterator(split, context))
override def compute(split: Partition, context: TaskContext): Iterator[U] = {
// SGX - ShuffleAggregation is actually performed here
val toRet = if (!SparkEnv.get.conf.isSGXWorkerEnabled()) {
f(context, split.index, firstParent[T].iterator(split, context))
} else {
// Trigger iterator pipeline and gather closures
firstParent[T].iterator(split, context)
for (parFunc <- firstParent[T].funcBuff) {
if (!funcBuff.contains(parFunc)) {
funcBuff.append(parFunc)
}
}
// TODO: support mapPartitionsWithIndex Case?
assert(cleanFunc != null)
if (!funcBuff.contains(cleanFunc)) funcBuff.append(cleanFunc)
firstParent[T].iterator(split, context).asInstanceOf[Iterator[U]]
}
toRet
}

override def clearDependencies() {
super.clearDependencies()
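
A hedged illustration of the funcBuff bookkeeping above (identifiers hypothetical; assumes a SparkContext sc whose conf reports isSGXWorkerEnabled() as true): in SGX mode compute() gathers the per-transformation closures rather than applying f locally, so once a partition of the last RDD is computed its funcBuff holds one Iterator[Any]-level closure per narrow transformation in the lineage.

val parsed = sc.parallelize(Seq("1", "2", "3")).map(_.toInt)
val evens  = parsed.filter(_ % 2 == 0)

// After a partition of `evens` is computed in SGX mode, roughly:
//   evens.funcBuff == ArrayBuffer(map's cleanFunc    (it => it.map(toInt)),
//                                 filter's cleanFunc (it => it.filter(isEven)))
// The closures themselves only execute inside the SGX worker, shipped via SGXRunner.
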
27 changes: 20 additions & 7 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -101,6 +101,8 @@ abstract class RDD[T: ClassTag](
_sc
}

private[spark] val funcBuff: ArrayBuffer[(Iterator[Any]) => Any] = mutable.ArrayBuffer[(Iterator[Any]) => Any]()

/** Construct an RDD with just a one-to-one dependency on one parent */
def this(@transient oneParent: RDD[_]) =
this(oneParent.context, List(new OneToOneDependency(oneParent)))
@@ -369,8 +371,11 @@
* Return a new RDD by applying a function to all elements of this RDD.
*/
def map[U: ClassTag](f: T => U): RDD[U] = withScope {
val cleanF = sc.clean(f)
new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF))
val cleanF = sc.clean[Any => Any](f.asInstanceOf[Any => Any])
new MapPartitionsRDD[U, T](
this,
f = (context, pid, iter) => iter.map(cleanF.asInstanceOf[T => U]),
cleanFunc = (iter: Iterator[Any]) => iter.map(cleanF))
}

/**
Expand All @@ -379,17 +384,20 @@ abstract class RDD[T: ClassTag](
*/
def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] = withScope {
val cleanF = sc.clean(f)
new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.flatMap(cleanF))
new MapPartitionsRDD[U, T](this,
(context, pid, iter) => iter.flatMap(cleanF),
cleanFunc = (iter: Iterator[Any]) => iter.flatMap(cleanF.asInstanceOf[Any => Iterator[Any]]))
}

/**
* Return a new RDD containing only the elements that satisfy a predicate.
*/
def filter(f: T => Boolean): RDD[T] = withScope {
val cleanF = sc.clean(f)
val cleanF = sc.clean[Any => Boolean](f.asInstanceOf[Any => Boolean])
new MapPartitionsRDD[T, T](
this,
(context, pid, iter) => iter.filter(cleanF),
cleanFunc = (iter: Iterator[Any]) => iter.filter(cleanF),
preservesPartitioning = true)
}

@@ -800,7 +808,8 @@ abstract class RDD[T: ClassTag](
new MapPartitionsRDD(
this,
(context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(iter),
preservesPartitioning)
cleanFunc = cleanedF.asInstanceOf[(Iterator[Any]) => Iterator[Any]],
preservesPartitioning = preservesPartitioning)
}

/**
@@ -819,9 +828,11 @@
f: (Int, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false,
isOrderSensitive: Boolean = false): RDD[U] = withScope {
val cleanedF = sc.clean(f)
new MapPartitionsRDD(
this,
(context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter),
cleanFunc = cleanedF.asInstanceOf[(Iterator[Any]) => Iterator[Any]],
preservesPartitioning = preservesPartitioning,
isOrderSensitive = isOrderSensitive)
}
@@ -832,10 +843,12 @@
private[spark] def mapPartitionsInternal[U: ClassTag](
f: Iterator[T] => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = withScope {
val cleanedF = sc.clean(f)
new MapPartitionsRDD(
this,
(context: TaskContext, index: Int, iter: Iterator[T]) => f(iter),
preservesPartitioning)
cleanFunc = cleanedF.asInstanceOf[(Iterator[Any]) => Iterator[Any]],
preservesPartitioning = preservesPartitioning)
}

/**
@@ -852,7 +865,7 @@
new MapPartitionsRDD(
this,
(context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter),
preservesPartitioning)
preservesPartitioning = preservesPartitioning)
}

/**
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
@@ -49,7 +49,7 @@ class RDDBarrier[T: ClassTag] private[spark] (rdd: RDD[T]) {
new MapPartitionsRDD(
rdd,
(context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(iter),
preservesPartitioning,
preservesPartitioning = preservesPartitioning,
isFromBarrier = true
)
}
35 changes: 35 additions & 0 deletions core/src/main/scala/org/apache/spark/util/SGXUtils.scala
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util


object SGXUtils {
/** Closures used for SGX tests should be written here (not in Tests) as SGXRunner is using
* classLoader to find the appropriate anonymous function */
val filterEvenNumFunc = (t: Int) => t % 2 == 0

val mapIncrementOneFunc = (v: Int) => v + 1
val mapMultiplyByTwoFunc = (v: Int) => v * 2

val mapPartitionsSum = (iter: Iterator[Int]) => Iterator(iter.sum)

val mapPartitionsWithIndex = (split: Int, iter: Iterator[Int]) => Iterator((split, iter.sum))
val mapToList = (iter: Array[Int]) => iter.toList

val flatMapOneToVal: (Int) => TraversableOnce[Int] = (x: Int) => 1 to x
}
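
For reference, a hedged end-to-end sketch of how these helpers might be exercised (not part of this commit; assumes a running SparkContext sc with the SGX worker enabled):

import org.apache.spark.util.SGXUtils

// Chain two of the predefined closures so the worker's classLoader can resolve them,
// as the comment above requires.
val result = sc.parallelize(1 to 10)
  .map(SGXUtils.mapIncrementOneFunc)    // 2, 3, ..., 11
  .filter(SGXUtils.filterEvenNumFunc)   // keep the even values
  .collect()                            // Array(2, 4, 6, 8, 10)
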
