Commit

Supporting byPassMergeSort shuffle for SGX (#17)
* Introducing a new SHUFFLE_MAP_BYPASS function type for the bypass-merge-sort shuffle

* Adding numOfPartitions as part of SGX RDD (needed by the shuffle partitioner)

* Introducing an SGX partitioner - hash-based for now (to be replaced by any pseudorandom function we want)

* Introducing an sgxWrite interface as part of the shuffle writer (see the sketch after this list)
When SGX is enabled, this path is used instead of the regular write path
The method takes as arguments both the encrypted records AND the record-partition mapping (recordMapping) as returned by the enclave worker

* Handling the SHUFFLE_MAP_BYPASS task type in the SGX worker - returns a record-partition mapping iterator
* ShuffleMapTask checks whether SGX is enabled and calls the appropriate shuffle interface

* Bypass-merge-sort shuffle test case now runs through the SGXWorker (#15)
Multiple partitions are supported
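
A minimal, self-contained sketch of the flow above (illustrative only: `enclavePartitionMapping` and `routeByMapping` are made-up names that stand in for the hash-based partitioning SGXPartitioner performs inside the enclave worker and for the lookup-driven routing sgxWrite performs on the untrusted side):

```scala
object SgxShuffleFlowSketch {

  // Enclave side: pick each key's reduce partition with a hash-based scheme
  // (what SGXPartitioner does) and return the mapping to the untrusted host.
  def enclavePartitionMapping[K](keys: Iterator[K], numPartitions: Int): Map[K, Int] =
    keys.map { k =>
      val mod = k.hashCode % numPartitions
      k -> (if (mod < 0) mod + numPartitions else mod) // nonNegativeMod
    }.toMap

  // Untrusted side: route each (still encrypted) record purely by looking its
  // key up in the enclave-provided mapping, never by hashing plaintext itself.
  def routeByMapping[K, V](records: Iterator[(K, V)],
                           mapping: Map[K, Int],
                           numPartitions: Int): Array[List[(K, V)]] = {
    val buckets = Array.fill(numPartitions)(List.empty[(K, V)])
    records.foreach { case (k, v) =>
      val p = mapping(k)
      buckets(p) = (k, v) :: buckets(p)
    }
    buckets
  }

  def main(args: Array[String]): Unit = {
    val records = List("a" -> 1, "b" -> 2, "c" -> 3)
    val mapping = enclavePartitionMapping(records.iterator.map(_._1), numPartitions = 2)
    routeByMapping(records.iterator, mapping, numPartitions = 2).zipWithIndex.foreach {
      case (bucket, p) => println(s"partition $p: ${bucket.reverse}")
    }
  }
}
```

The point of the indirection is that the untrusted shuffle writer never needs to inspect or hash plaintext keys; it only follows the mapping the enclave hands back.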
pgaref authored Aug 1, 2019
1 parent 8f51466 commit c250aa7
Showing 12 changed files with 205 additions and 43 deletions.
BypassMergeSortShuffleWriter.java
@@ -17,23 +17,13 @@

package org.apache.spark.shuffle.sort;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Closeables;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import javax.annotation.Nullable;

import scala.None$;
import scala.Option;
import scala.Product2;
import scala.Tuple2;
import scala.collection.Iterator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Closeables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
@@ -45,8 +35,19 @@
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.*;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.FileSegment;
import org.apache.spark.storage.TempShuffleBlockId;
import org.apache.spark.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.None$;
import scala.Option;
import scala.Product2;
import scala.Tuple2;
import scala.collection.Iterator;

/**
* This class implements sort-based shuffle's hash-style shuffle fallback path. This write path
@@ -69,6 +70,13 @@
* refactored into its own class in order to reduce code complexity; see SPARK-7855 for details.
* <p>
* There have been proposals to completely remove this code path; see SPARK-6026 for details.
*
* The bypassMergeThreshold parameter and associated use of a hash-ish shuffle when the number of
* partitions is less than this, is basically a workaround for SparkSQL, because the fact that the
* sort-based shuffle stores non-serialized objects is a deal-breaker for SparkSQL,
* which re-uses objects. Once the sort-based shuffle is changed to store serialized objects,
* we should never be secretly doing hash-ish shuffle even when the user has specified to use
* sort-based shuffle (because of its otherwise worse performance).
*/
final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {

@@ -170,6 +178,59 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}

@Override
public void sgxWrite(Iterator<Product2<K, V>> records, java.util.Map<K, Integer> recordMapping) throws IOException {
assert (partitionWriters == null);
logger.debug("sgxWriter...");
if (!records.hasNext()) {
partitionLengths = new long[numPartitions];
shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
return;
}
final SerializerInstance serInstance = serializer.newInstance();
final long openStartTime = System.nanoTime();
partitionWriters = new DiskBlockObjectWriter[numPartitions];
partitionWriterSegments = new FileSegment[numPartitions];
for (int i = 0; i < numPartitions; i++) {
final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
blockManager.diskBlockManager().createTempShuffleBlock();
final File file = tempShuffleBlockIdPlusFile._2();
final BlockId blockId = tempShuffleBlockIdPlusFile._1();
partitionWriters[i] =
blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
}
// Creating the file to write to and creating a disk writer both involve interacting with
// the disk, and can take a long time in aggregate when we open many files, so should be
// included in the shuffle write time.
writeMetrics.incWriteTime(System.nanoTime() - openStartTime);

// Using the encrypted key mapping as propagated by the enclave
while (records.hasNext()) {
final Product2<K, V> record = records.next();
final K key = record._1();
partitionWriters[recordMapping.get(key)].write(key, record._2());
}

for (int i = 0; i < numPartitions; i++) {
final DiskBlockObjectWriter writer = partitionWriters[i];
partitionWriterSegments[i] = writer.commitAndGet();
writer.close();
}

File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
File tmp = Utils.tempFileWith(output);
try {
partitionLengths = writePartitionedFile(tmp);
shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
} finally {
if (tmp.exists() && !tmp.delete()) {
logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
}
}
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}

@VisibleForTesting
long[] getPartitionLengths() {
return partitionLengths;
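
As the class-level comment at the top of this file explains, the bypass path writes one file per reduce partition and then concatenates them into a single map output file. A stripped-down sketch of that concatenation step (this is not Spark's writePartitionedFile; the real method also tracks write metrics and handles partitions with no records, among other details):

```scala
import java.io.{File, FileInputStream, FileOutputStream}

object ConcatSketch {
  // Concatenate the per-partition temp files into `output` and return each
  // partition's byte length, which is what the shuffle index file is built from.
  def concatenatePartitionFiles(partitionFiles: Seq[File], output: File): Array[Long] = {
    val out = new FileOutputStream(output)
    val buf = new Array[Byte](8192)
    try {
      partitionFiles.map { file =>
        val in = new FileInputStream(file)
        var segmentLength = 0L
        try {
          var read = in.read(buf)
          while (read != -1) { // copy this partition's bytes verbatim
            out.write(buf, 0, read)
            segmentLength += read
            read = in.read(buf)
          }
        } finally {
          in.close()
        }
        segmentLength
      }.toArray
    } finally {
      out.close()
    }
  }
}
```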
UnsafeShuffleWriter.java
@@ -17,32 +17,33 @@

package org.apache.spark.shuffle.sort;

import javax.annotation.Nullable;
import java.io.*;
import java.nio.channels.FileChannel;
import java.util.Iterator;

import scala.Option;
import scala.Product2;
import scala.collection.JavaConverters;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.ByteStreams;
import com.google.common.io.Closeables;
import com.google.common.io.Files;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.*;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.channels.FileChannel;
import java.util.Iterator;
import javax.annotation.Nullable;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.apache.commons.io.output.CountingOutputStream;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.annotation.Private;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.internal.config.package$;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.NioBufferedFileInputStream;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.apache.commons.io.output.CountingOutputStream;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
@@ -55,7 +56,13 @@
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.Utils;
import org.apache.spark.internal.config.package$;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;
import scala.Product2;
import scala.collection.JavaConverters;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

@Private
public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@@ -207,6 +214,12 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOEx
}
}

@Override
public void sgxWrite(scala.collection.Iterator<Product2<K, V>> records,
java.util.Map<K, Integer> recordMapping) throws IOException {
throw new RuntimeException("Not implemented yet!");
}

private void open() {
assert (sorter == null);
sorter = new ShuffleExternalSorter(
27 changes: 27 additions & 0 deletions core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -129,6 +129,33 @@ class HashPartitioner(partitions: Int) extends Partitioner {
override def hashCode: Int = numPartitions
}


/**
* A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using
* Java's `Object.hashCode`.
*
* TODO: Make sure that the partitioner satisfies the pseudorandom properties we need
*/
class SGXPartitioner(partitions: Int) extends Partitioner {
require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

def numPartitions: Int = partitions

def getPartition(key: Any): Int = key match {
case null => 0
case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
}

override def equals(other: Any): Boolean = other match {
case h: SGXPartitioner =>
h.numPartitions == numPartitions
case _ =>
false
}

override def hashCode: Int = numPartitions
}

/**
* A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly
* equal ranges. The ranges are determined by sampling the content of the RDD passed in.
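
A quick usage sketch of the partitioner added above (script-style, e.g. for a spark-shell session on this branch; the string keys are purely illustrative, since on the SGX path the keys are encrypted values coming out of the enclave):

```scala
import org.apache.spark.SGXPartitioner

// Keys are mapped deterministically to one of numPartitions buckets via
// Utils.nonNegativeMod(key.hashCode, numPartitions); null keys go to partition 0.
val partitioner = new SGXPartitioner(4)
Seq("alpha", "beta", "gamma", null).foreach { key =>
  println(s"$key -> partition ${partitioner.getPartition(key)}")
}
```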
10 changes: 7 additions & 3 deletions core/src/main/scala/org/apache/spark/api/sgx/SGXBaseRunner.scala
@@ -50,7 +50,8 @@

def compute(inputIterator: Iterator[IN],
partitionIndex: Int,
context: TaskContext): Iterator[OUT] = {
context: TaskContext,
numOfPartitions: Int = 1): Iterator[OUT] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
@@ -62,7 +63,7 @@
val releasedOrClosed = new AtomicBoolean(false)

// Start a thread to feed the process input from our parent's iterator
val writerThread = sgxWriterThread(env, worker, inputIterator, partitionIndex, context)
val writerThread = sgxWriterThread(env, worker, inputIterator, numOfPartitions, partitionIndex, context)
// Add task completion Listener
context.addTaskCompletionListener[Unit] { _ =>
writerThread.shutdownOnTaskCompletion()
@@ -88,6 +89,7 @@
protected def sgxWriterThread(env: SparkEnv,
worker: Socket,
inputIterator: Iterator[IN],
numOfPartitions: Int,
partitionIndex: Int,
context: TaskContext): WriterIterator

@@ -98,6 +100,7 @@
abstract class WriterIterator(env: SparkEnv,
worker: Socket,
inputIterator: Iterator[IN],
numOfPartitions: Int,
partitionIndex: Int,
context: TaskContext)
extends Thread(s"stdout writer for SGXRunner TID:${context.taskAttemptId()}") {
@@ -146,6 +149,7 @@ private[spark] abstract class SGXBaseRunner[IN: ClassTag, OUT: ClassTag](
// Write out the TaskContextInfo
dataOut.writeInt(boundPort)
dataOut.writeInt(context.stageId())
dataOut.writeInt(numOfPartitions)
dataOut.writeInt(context.partitionId())
dataOut.writeInt(context.attemptNumber())
dataOut.writeLong(context.taskAttemptId())
@@ -319,7 +323,7 @@ private[spark] abstract class SGXBaseRunner[IN: ClassTag, OUT: ClassTag](
private[spark] object SGXFunctionType {
val NON_UDF = 0
val PIPELINED = 1
val SHUFFLE_MAP = 2
val SHUFFLE_MAP_BYPASS = 2
val SHUFFLE_REDUCE = 3
val BATCHED_UDF = 100
def toString(sgxFuncType: Int): String = sgxFuncType match {
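
The runner now writes numOfPartitions into the task-context handshake, and SGXWorker (further down in this commit) reads it back in the same position, right after stageId. A script-style sketch of that framing, with in-memory byte streams standing in for the worker socket (field values are arbitrary; only the field order mirrors this change):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Writer side (what SGXBaseRunner's writer thread emits)...
val buffer = new ByteArrayOutputStream()
val dataOut = new DataOutputStream(buffer)
dataOut.writeInt(9999)  // boundPort
dataOut.writeInt(3)     // stageId
dataOut.writeInt(8)     // numOfPartitions (the field added by this commit)
dataOut.writeInt(1)     // partitionId
dataOut.writeInt(0)     // attemptNumber
dataOut.writeLong(42L)  // taskAttemptId
dataOut.flush()

// ...and the matching reader side (what SGXWorker expects): both ends must
// stay in lockstep, otherwise every later field is misread.
val dataIn = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
val boundPort       = dataIn.readInt()
val stageId         = dataIn.readInt()
val numOfPartitions = dataIn.readInt()
val partitionId     = dataIn.readInt()
val attemptNumber   = dataIn.readInt()
val taskAttemptId   = dataIn.readLong()
println(s"stage=$stageId partitions=$numOfPartitions partition=$partitionId " +
  s"attempt=$attemptNumber task=$taskAttemptId port=$boundPort")
```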
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/api/sgx/SGXRDD.scala
@@ -44,7 +44,8 @@ private[spark] class SGXRDD(parent: RDD[_],
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val runner = SGXRunner(func)
firstParent.iterator(split, context)
val runner = SGXRunner(func, if (parent.funcBuff.isEmpty) SGXFunctionType.NON_UDF else SGXFunctionType.PIPELINED, parent.funcBuff)
runner.compute(firstParent.iterator(split, context), split.index, context)
}

6 changes: 4 additions & 2 deletions core/src/main/scala/org/apache/spark/api/sgx/SGXRunner.scala
@@ -45,8 +45,10 @@ private[spark] class SGXRunner(func: (Iterator[Any]) => Any, funcType: Int, func
override protected def sgxWriterThread(env: SparkEnv,
worker: Socket,
inputIterator: Iterator[Array[Byte]],
partitionIndex: Int, context: TaskContext): WriterIterator = {
new WriterIterator(env, worker, inputIterator, partitionIndex, context) {
numOfPartitions: Int,
partitionIndex: Int,
context: TaskContext): WriterIterator = {
new WriterIterator(env, worker, inputIterator, numOfPartitions, partitionIndex, context) {
/** Writes a command section to the stream connected to the SGX worker */
override protected def writeFunction(dataOut: DataOutputStream): Unit = {
logInfo(s"Ser ${funcs.size + 1} closures")
SGXWorker.scala
@@ -25,7 +25,7 @@ import org.apache.spark.api.sgx.{SGXException, SGXFunctionType, SGXRDD, SpecialS
import org.apache.spark.internal.Logging
import org.apache.spark.serializer._
import org.apache.spark.util.Utils
import org.apache.spark.{SparkConf, SparkException, TaskContext}
import org.apache.spark.{SGXPartitioner, SparkConf, SparkException, TaskContext}

import scala.collection.mutable
import scala.reflect.ClassTag
@@ -52,6 +52,7 @@ private[spark] class SGXWorker(closuseSer: SerializerInstance, dataSer: Serializ
val boundPort = inSock.readInt()
val taskContext = TaskContext.get()
val stageId = inSock.readInt()
val numOfPartitions = inSock.readInt()
val partitionId = inSock.readInt()
val attemptId = inSock.readInt()
val taskAttemptId = inSock.readLong()
@@ -72,6 +73,21 @@ private[spark] class SGXWorker(closuseSer: SerializerInstance, dataSer: Serializ
logInfo(s"Executing ${funcArray.size} (pipelined) funcs")

eval_type match {

case SGXFunctionType.SHUFFLE_MAP_BYPASS =>
logDebug(s"ShuffleMap Bypass #Partitions ${numOfPartitions}")
val iterator = new ReaderIterator(inSock, dataSer).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]
val sgxPartitioner = new SGXPartitioner(numOfPartitions)
// Mapping of encrypted keys to partitions (needed by the shuffle writer)
val keyMapping = scala.collection.mutable.Map[Any, Any]()
while (iterator.hasNext) {
val record = iterator.next()
keyMapping(record._1) = sgxPartitioner.getPartition(record._1)
}
SGXRDD.writeIteratorToStream[Any](keyMapping.toIterator, dataSer, outSock)
outSock.writeInt(SpecialSGXChars.END_OF_DATA_SECTION)
outSock.flush()

case SGXFunctionType.NON_UDF =>
// Read Iterator
val iterator = new ReaderIterator(inSock, dataSer)
ShuffleMapTask.scala
@@ -21,13 +21,16 @@ import java.lang.management.ManagementFactory
import java.nio.ByteBuffer
import java.util.Properties

import scala.collection.JavaConverters._
import scala.language.existentials

import org.apache.spark._
import org.apache.spark.api.sgx.{SGXFunctionType, SGXRunner}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.shuffle.ShuffleWriter
import org.apache.spark.util.SGXUtils

/**
* A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
@@ -96,7 +99,18 @@ private[spark] class ShuffleMapTask(
try {
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])

if (SparkEnv.get.conf.isSGXWorkerEnabled()) {
val runner = SGXRunner(SGXUtils.toIteratorSizeSGXFunc, SGXFunctionType.SHUFFLE_MAP_BYPASS)
// Need to explicitly set the number of partitions here
val keyMapping = runner.compute(rdd.iterator(partition, context).asInstanceOf[Iterator[Array[Byte]]],
partitionId, context, dep.partitioner.numPartitions).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]
val keyMap = scala.collection.mutable.Map[Any, Integer]()
for (i <- keyMapping) keyMap(i._1) = i._2.asInstanceOf[Integer]
writer.sgxWrite(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]], keyMap.asJava)
} else {
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
}
writer.stop(success = true).get
} catch {
case e: Exception =>
ShuffleWriter.scala
@@ -29,6 +29,10 @@ private[spark] abstract class ShuffleWriter[K, V] {
@throws[IOException]
def write(records: Iterator[Product2[K, V]]): Unit

/** Write a sequence of encrypted records to this task's output, following the given record-key-to-partition mapping. */
@throws[IOException]
def sgxWrite(records: Iterator[Product2[K, V]], recordMapping: java.util.Map[K, Integer]): Unit

/** Close this writer, passing along whether the map completed */
def stop(success: Boolean): Option[MapStatus]
}
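
A toy illustration of the contract this new abstract method adds (LoggingShuffleWriter is hypothetical and does not extend the real private[spark] ShuffleWriter; it only contrasts write, where the writer picks partitions itself, with sgxWrite, where the enclave-provided mapping decides):

```scala
import java.{util => ju}

// Hypothetical writer, for illustration only.
class LoggingShuffleWriter[K, V](numPartitions: Int) {

  // Regular path: the writer's own (hash-based) partitioner decides.
  def write(records: Iterator[Product2[K, V]]): Unit =
    records.foreach { r =>
      val mod = r._1.hashCode % numPartitions
      val p = if (mod < 0) mod + numPartitions else mod
      println(s"write: ${r._1} -> partition $p")
    }

  // SGX path: the enclave-provided record-key -> partition mapping decides.
  def sgxWrite(records: Iterator[Product2[K, V]], recordMapping: ju.Map[K, Integer]): Unit =
    records.foreach { r =>
      println(s"sgxWrite: ${r._1} -> partition ${recordMapping.get(r._1)}")
    }
}

val mapping = new ju.HashMap[String, Integer]()
mapping.put("a", 0)
mapping.put("b", 1)
new LoggingShuffleWriter[String, Int](2).sgxWrite(Iterator("a" -> 1, "b" -> 2), mapping)
```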
