Skip to content

Commit

Permalink
In memory shuffle (cherry-picked from #135)
Browse files Browse the repository at this point in the history
  • Loading branch information
ankurdave committed Mar 28, 2014
1 parent f36e576 commit 5ec645d
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.rdd.RDDCheckpointData
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
import java.nio.ByteBuffer

private[spark] object ShuffleMapTask {

Expand Down Expand Up @@ -168,7 +169,11 @@ private[spark] class ShuffleMapTask(
var totalBytes = 0L
var totalTime = 0L
val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
writer.commit()
// writer.commit()
val bytes = writer.commit()
if (bytes != null) {
blockManager.putBytes(writer.blockId, ByteBuffer.wrap(bytes), StorageLevel.MEMORY_ONLY_SER, tellMaster = false)
}
writer.close()
val size = writer.fileSegment().length
totalBytes += size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ private[spark] class BlockManager(

private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]

private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
private[storage] val memoryStore = new MemoryStore(this, maxMemory)
private[storage] val diskStore = new DiskStore(this, diskBlockManager)

// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
Expand Down Expand Up @@ -293,7 +293,7 @@ private[spark] class BlockManager(
* never deletes (recent) items.
*/
def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
diskStore.getValues(blockId, serializer).orElse(
memoryStore.getValues(blockId, serializer).orElse(
sys.error("Block " + blockId + " not found on disk, though it should be"))
}

Expand All @@ -313,7 +313,7 @@ private[spark] class BlockManager(
// As an optimization for map output fetches, if the block is for a shuffle, return it
// without acquiring a lock; the disk store never deletes (recent) items so this should work
if (blockId.isShuffle) {
diskStore.getBytes(blockId) match {
memoryStore.getBytes(blockId) match {
case Some(bytes) =>
Some(bytes)
case None =>
Expand Down Expand Up @@ -831,7 +831,7 @@ private[spark] class BlockManager(
if (info != null) info.synchronized {
// Removals are idempotent in disk store and memory store. At worst, we get a warning.
val removedFromMemory = memoryStore.remove(blockId)
val removedFromDisk = diskStore.remove(blockId)
val removedFromDisk = false //diskStore.remove(blockId)
if (!removedFromMemory && !removedFromDisk) {
logWarning("Block " + blockId + " could not be removed as it was not found in either " +
"the disk or memory store")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.storage

import java.io.{FileOutputStream, File, OutputStream}
import java.io.{ByteArrayOutputStream, FileOutputStream, File, OutputStream}
import java.nio.channels.FileChannel

import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
Expand All @@ -44,7 +44,7 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
* Flush the partial writes and commit them as a single atomic block. Return the
* number of bytes written for this commit.
*/
def commit(): Long
def commit(): Array[Byte]

/**
* Reverts writes that haven't been flushed yet. Callers should invoke this function
Expand Down Expand Up @@ -106,7 +106,7 @@ private[spark] class DiskBlockObjectWriter(
/** The file channel, used for repositioning / truncating the file. */
private var channel: FileChannel = null
private var bs: OutputStream = null
private var fos: FileOutputStream = null
private var fos: ByteArrayOutputStream = null
private var ts: TimeTrackingOutputStream = null
private var objOut: SerializationStream = null
private val initialPosition = file.length()
Expand All @@ -115,9 +115,8 @@ private[spark] class DiskBlockObjectWriter(
private var _timeWriting = 0L

override def open(): BlockObjectWriter = {
fos = new FileOutputStream(file, true)
fos = new ByteArrayOutputStream()
ts = new TimeTrackingOutputStream(fos)
channel = fos.getChannel()
lastValidPosition = initialPosition
bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
objOut = serializer.newInstance().serializeStream(bs)
Expand All @@ -130,9 +129,6 @@ private[spark] class DiskBlockObjectWriter(
if (syncWrites) {
// Force outstanding writes to disk and track how long it takes
objOut.flush()
val start = System.nanoTime()
fos.getFD.sync()
_timeWriting += System.nanoTime() - start
}
objOut.close()

Expand All @@ -149,18 +145,18 @@ private[spark] class DiskBlockObjectWriter(

override def isOpen: Boolean = objOut != null

override def commit(): Long = {
override def commit(): Array[Byte] = {
if (initialized) {
// NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the
// serializer stream and the lower level stream.
objOut.flush()
bs.flush()
val prevPos = lastValidPosition
lastValidPosition = channel.position()
lastValidPosition - prevPos
lastValidPosition = fos.size()
fos.toByteArray
} else {
// lastValidPosition is zero if stream is uninitialized
lastValidPosition
null
}
}

Expand All @@ -170,7 +166,7 @@ private[spark] class DiskBlockObjectWriter(
// truncate the file to the last valid position.
objOut.flush()
bs.flush()
channel.truncate(lastValidPosition)
throw new UnsupportedOperationException("Revert temporarily broken due to in memory shuffle code changes.")
}
}

Expand All @@ -182,7 +178,7 @@ private[spark] class DiskBlockObjectWriter(
}

override def fileSegment(): FileSegment = {
new FileSegment(file, initialPosition, bytesWritten)
new FileSegment(null, initialPosition, bytesWritten)
}

// Only valid if called after close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import java.util.LinkedHashMap
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.util.{SizeEstimator, Utils}
import org.apache.spark.serializer.Serializer

/**
* Stores blocks in memory, either as ArrayBuffers of deserialized Java objects or as
Expand Down Expand Up @@ -119,6 +120,14 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}

/**
* A version of getValues that allows a custom serializer. This is used as part of the
* shuffle short-circuit code.
*/
def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer))
}

override def remove(blockId: BlockId): Boolean = {
entries.synchronized {
val entry = entries.remove(blockId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,17 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
}
})
}

def removeAllShuffleStuff() {
for (state <- shuffleStates.values;
group <- state.allFileGroups;
(mapId, _) <- group.mapIdToIndex.iterator;
reducerId <- 0 until group.files.length) {
val blockId = new ShuffleBlockId(group.shuffleId, mapId, reducerId)
blockManager.removeBlock(blockId, tellMaster = false)
}
shuffleStates.clear()
}
}

private[spark]
Expand All @@ -200,7 +211,7 @@ object ShuffleBlockManager {
* Stores the absolute index of each mapId in the files of this group. For instance,
* if mapId 5 is the first block in each file, mapIdToIndex(5) = 0.
*/
private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()
val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()

/**
* Stores consecutive offsets of blocks into each reducer file, ordered by position in the file.
Expand Down
7 changes: 7 additions & 0 deletions graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.graphx

import scala.reflect.ClassTag
import org.apache.spark.Logging
import org.apache.spark.SparkEnv


/**
Expand Down Expand Up @@ -143,6 +144,12 @@ object Pregel extends Logging {
// hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the
// vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g).
activeMessages = messages.count()

// Very ugly code to clear the in-memory shuffle data
messages.foreachPartition { iter =>
SparkEnv.get.blockManager.shuffleBlockManager.removeAllShuffleStuff()
}

// Unpersist the RDDs hidden by newly-materialized RDDs
oldMessages.unpersist(blocking=false)
newVerts.unpersist(blocking=false)
Expand Down

0 comments on commit 5ec645d

Please sign in to comment.