diff --git a/spark-on-angel/graph/src/main/scala/com/tencent/angel/graph/community/copra/COPRAGraphPartition.scala b/spark-on-angel/graph/src/main/scala/com/tencent/angel/graph/community/copra/COPRAGraphPartition.scala index 6c7cca172..1e5ef1064 100644 --- a/spark-on-angel/graph/src/main/scala/com/tencent/angel/graph/community/copra/COPRAGraphPartition.scala +++ b/spark-on-angel/graph/src/main/scala/com/tencent/angel/graph/community/copra/COPRAGraphPartition.scala @@ -17,10 +17,11 @@ package com.tencent.angel.graph.community.copra +import com.tencent.angel.graph.psf.triangle.NeighborsFloatAttrsElement import it.unimi.dsi.fastutil.floats.FloatArrayList import it.unimi.dsi.fastutil.ints.IntArrayList import it.unimi.dsi.fastutil.longs.{Long2ObjectOpenHashMap, LongArrayList} -import com.tencent.angel.graph.psf.triangle.NeighborsFloatAttrsElement + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random @@ -38,8 +39,9 @@ class COPRAGraphPartition(index: Int, def analysis(): Int = neighbors.length def initMsgs(model: COPRAPSModel): Unit = { - keys.sliding(batchSize, batchSize).foreach {partKeys => - initPartMsgs(partKeys, model)} + keys.sliding(batchSize, batchSize).foreach { partKeys => + initPartMsgs(partKeys, model) + } } def initPartMsgs(partKeys: Array[Long], model: COPRAPSModel): Unit = { @@ -54,7 +56,7 @@ class COPRAGraphPartition(index: Int, println(s"partition $index: ---------- iteration $iteration starts ----------") keys.indices.sliding(batchSize, batchSize).foreach { nodesIndex => val beforeCalcPullNodesTs = System.currentTimeMillis() - val pullNodes = neighbors.slice(indptr(nodesIndex.head), indptr(nodesIndex.last+1)).distinct + val pullNodes = neighbors.slice(indptr(nodesIndex.head), indptr(nodesIndex.last + 1)).distinct println(s"partition $index: calculating pull nodes cost ${System.currentTimeMillis() - beforeCalcPullNodesTs} ms") val beforePullCoeTs = System.currentTimeMillis() @@ -101,17 +103,19 @@ class COPRAGraphPartition(index: Int, while (j < indptr(idx + 1)) { assert(comInMsgs.containsKey(neighbors(j)), s"Key ${neighbors(j)} is not in the keySet of comInMsgs.") val comGet = comInMsgs.get(neighbors(j)) - val neis = comGet.getNeighborIds - val coes = comGet.getAttrs - var x = 0 - while (x < comGet.getNodesNum) { - val label = neis(x) - val coe = coes(x) - val t = temp.getOrElse(label, 0f) - temp += ((label, t + coe * weights(j))) - x += 1 + if (comGet != null) { + val neis = comGet.getNeighborIds + val coes = comGet.getAttrs + var x = 0 + while (x < comGet.getNodesNum) { + val label = neis(x) + val coe = coes(x) + val t = temp.getOrElse(label, 0f) + temp += ((label, t + coe * weights(j))) + x += 1 + } + j += 1 } - j += 1 } // temp.remove(-1) if (temp.isEmpty) { @@ -127,7 +131,7 @@ class COPRAGraphPartition(index: Int, } // normalize val norm = newAttr.values.sum - newAttr.map { case (node, coe) => (node, coe / norm)}.toArray + newAttr.map { case (node, coe) => (node, coe / norm) }.toArray } } @@ -143,7 +147,7 @@ object COPRAGraphPartition { val indptr = new IntArrayList() val keys = new LongArrayList() val neighbors = new LongArrayList() - val keyLabels = ArrayBuffer[Array[(Long,Float)]]() + val keyLabels = ArrayBuffer[Array[(Long, Float)]]() val weights = new FloatArrayList() indptr.add(0) @@ -165,7 +169,7 @@ object COPRAGraphPartition { val neighborsArray = neighbors.toLongArray() val indicesArray = keysArray.union(neighborsArray).distinct - new COPRAGraphPartition(index, numMaxCommunities,preserveRate, + new COPRAGraphPartition(index, numMaxCommunities, preserveRate, keysArray, indptr.toIntArray(), neighborsArray,