From c8a7bb1ac6c7acf664948bd1aea042780a1c6689 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Sat, 19 Apr 2014 20:10:56 -0700 Subject: [PATCH] Optionally checkpoint in Pregel --- .../scala/org/apache/spark/graphx/Pregel.scala | 9 ++++++++- .../org/apache/spark/graphx/lib/Analytics.scala | 10 ++++++++-- .../org/apache/spark/graphx/lib/PageRank.scala | 16 ++++++++-------- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 70d2c64c2..7fa4967d6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -114,12 +114,16 @@ object Pregel extends Logging { (graph: Graph[VD, ED], initialMsg: A, maxIterations: Int = Int.MaxValue, - activeDirection: EdgeDirection = EdgeDirection.Either) + activeDirection: EdgeDirection = EdgeDirection.Either, + checkpoint: Boolean = false) (vprog: (VertexId, VD, A) => VD, sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], mergeMsg: (A, A) => A) : Graph[VD, ED] = { + if (checkpoint) { + graph.edges.checkpoint() + } var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() // compute the messages var messages = g.mapReduceTriplets(sendMsg, mergeMsg) @@ -134,6 +138,9 @@ object Pregel extends Logging { prevG = g g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) } g.cache() + if (checkpoint) { + g.vertices.checkpoint() + } val oldMessages = messages // Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala index 4a3c3344d..9ed14eeac 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala @@ -59,6 +59,7 @@ object Analytics extends Logging { var numEPart = 4 var partitionStrategy: Option[PartitionStrategy] = None var numIterOpt: Option[Int] = None + var checkpointDirOpt: Option[String] = None options.foreach{ case ("tol", v) => tol = v.toFloat @@ -66,6 +67,7 @@ object Analytics extends Logging { case ("numEPart", v) => numEPart = v.toInt case ("partStrategy", v) => partitionStrategy = Some(pickPartitioner(v)) case ("numIter", v) => numIterOpt = Some(v.toInt) + case ("checkpointDir", v) => checkpointDirOpt = Some(v) case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt) } @@ -74,6 +76,10 @@ object Analytics extends Logging { println("======================================") val sc = new SparkContext(host, "PageRank(" + fname + ")", conf) + checkpointDirOpt match { + case Some(checkpointDir) => sc.setCheckpointDir(checkpointDir) + case None => {} + } val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname, minEdgePartitions = numEPart).cache() @@ -83,8 +89,8 @@ object Analytics extends Logging { println("GRAPHX: Number of edges " + graph.edges.count) val pr = (numIterOpt match { - case Some(numIter) => PageRank.run(graph, numIter) - case None => PageRank.runUntilConvergence(graph, tol) + case Some(numIter) => PageRank.run(graph, numIter, checkpoint = checkpointDirOpt.nonEmpty) + case None => PageRank.runUntilConvergence(graph, tol, checkpoint = checkpointDirOpt.nonEmpty) }).vertices.cache() println("GRAPHX: Total rank: " + pr.map(_._2).reduce(_ + _)) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 614555a05..d3851cdfd 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -77,8 +77,8 @@ object PageRank extends Logging { * */ def run[VD: ClassTag, ED: ClassTag]( - graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] = - { + graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15, checkpoint: Boolean = false) + : Graph[Double, Double] = { // Initialize the pagerankGraph with each edge attribute having // weight 1/outDegree and each vertex with attribute 1.0. val pagerankGraph: Graph[Double, Double] = graph @@ -101,8 +101,8 @@ object PageRank extends Logging { val initialMessage = 0.0 // Execute pregel for a fixed number of iterations. - Pregel(pagerankGraph, initialMessage, numIter, activeDirection = EdgeDirection.Out)( - vertexProgram, sendMessage, messageCombiner) + Pregel(pagerankGraph, initialMessage, numIter, activeDirection = EdgeDirection.Out, + checkpoint = checkpoint)(vertexProgram, sendMessage, messageCombiner) } /** @@ -120,8 +120,8 @@ object PageRank extends Logging { * containing the normalized weight. */ def runUntilConvergence[VD: ClassTag, ED: ClassTag]( - graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): Graph[Double, Double] = - { + graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15, checkpoint: Boolean = false) + : Graph[Double, Double] = { // Initialize the pagerankGraph with each edge attribute // having weight 1/outDegree and each vertex with attribute 1.0. val pagerankGraph: Graph[(Double, Double), Double] = graph @@ -157,8 +157,8 @@ object PageRank extends Logging { val initialMessage = resetProb / (1.0 - resetProb) // Execute a dynamic version of Pregel. - Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)( - vertexProgram, sendMessage, messageCombiner) + Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out, + checkpoint = checkpoint)(vertexProgram, sendMessage, messageCombiner) .mapVertices((vid, attr) => attr._1) } // end of deltaPageRank }