Skip to content

Commit

Permalink
Merge pull request #87 from amplab/saveload
Browse files Browse the repository at this point in the history
allow saving weights to file and loading weights from file, add NDArr…
  • Loading branch information
pcmoritz committed Feb 21, 2016
2 parents 0ba7d7d + 88acb0c commit b811223
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 27 deletions.
19 changes: 19 additions & 0 deletions src/main/java/libs/JavaNDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,23 @@ private void next(int[] indices) {
}
indices[axis] += 1;
}

/**
 * Approximate element-wise equality with another array.
 *
 * Returns false when the shapes differ (per JavaNDUtils.shapesEqual) or when
 * any pair of corresponding elements differs by more than {@code tol} in
 * absolute value; returns true otherwise. Two arrays with equal shapes and
 * no elements compare as equal.
 *
 * @param that the array to compare against
 * @param tol  maximum allowed absolute difference per element
 * @return true if shapes match and all elements are within {@code tol}
 */
public boolean equals(JavaNDArray that, float tol) {
  if (!JavaNDUtils.shapesEqual(shape, that.shape)) {
    return false;
  }
  // Loop-invariant: compute the element count once instead of on every iteration.
  final int numElements = JavaNDUtils.arrayProduct(shape);
  int[] indices = new int[dim];
  // the whole method can be optimized when we have the default strides
  for (int i = 0; i < numElements; i++) {
    if (Math.abs(get(indices) - that.get(indices)) > tol) {
      return false;
    }
    // Do not advance past the final element: the original code compared the
    // last element separately without calling next(), presumably because
    // next() is not defined there -- TODO confirm.
    if (i + 1 < numElements) {
      next(indices);
    }
  }
  return true;
}
}
18 changes: 18 additions & 0 deletions src/main/scala/libs/CaffeNet.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package libs

import java.io._
import java.nio.file.{Paths, Files}

import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}
import org.bytedeco.javacpp.caffe._
Expand Down Expand Up @@ -146,6 +149,21 @@ class CaffeNet(netParam: NetParameter, schema: StructType, preprocessor: Preproc
}
}

/**
 * Copies trained layer weights from the file at `filepath` into the
 * underlying net by delegating to `caffeNet.CopyTrainedLayersFrom`.
 *
 * @param filepath location of the weights file to load
 * @throws IllegalArgumentException if nothing exists at `filepath`
 */
def copyTrainedLayersFrom(filepath: String) = {
  val weightsPath = Paths.get(filepath)
  val found = Files.exists(weightsPath)
  if (found) {
    caffeNet.CopyTrainedLayersFrom(filepath)
  } else {
    // Fail in the JVM with a clear message before handing the path to the native call.
    throw new IllegalArgumentException("The file " + filepath + " does not exist.\n")
  }
}

/**
 * Writes the net to `filepath` as a binary proto file.
 *
 * The net is exported into a fresh `NetParameter` via `caffeNet.ToProto`
 * and written with `WriteProtoToBinaryFile`. Missing parent directories of
 * `filepath` are created first.
 *
 * @param filepath destination file; may be a bare file name with no
 *                 directory component
 */
def saveWeightsToFile(filepath: String) = {
  val target = new File(filepath)
  // getParentFile returns null when the path has no parent component
  // (e.g. "weights.caffemodel"); only create directories when one is named.
  Option(target.getParentFile).foreach(_.mkdirs())
  // Named `proto` so it does not shadow the class's `netParam` constructor parameter.
  val proto = new NetParameter()
  caffeNet.ToProto(proto)
  WriteProtoToBinaryFile(proto, filepath)
}

def outputSchema(): StructType = {
val fields = Array.range(0, numOutputs).map(i => {
val output = caffeNet.blob_names().get(caffeNet.output_blob_indices().get(i)).getString
Expand Down
4 changes: 4 additions & 0 deletions src/main/scala/libs/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,8 @@ object NDArray {
v.add(v2)
v
}

/**
 * Approximate equality of two arrays: delegates to the tolerance-based
 * `equals` of the underlying `javaArray`.
 *
 * @param v1  first array
 * @param v2  second array
 * @param tol tolerance forwarded to the underlying comparison
 * @return the result of `v1.javaArray.equals(v2.javaArray, tol)`
 */
def checkEqual(v1: NDArray, v2: NDArray, tol: Float): Boolean =
  v1.javaArray.equals(v2.javaArray, tol)
}
21 changes: 21 additions & 0 deletions src/main/scala/libs/WeightCollection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,25 @@ object WeightCollection extends java.io.Serializable {
}
return new WeightCollection(newWeights, layerNames)
}

/**
 * Compares two weight collections element-wise within a tolerance.
 *
 * The collections must be structurally identical: the same layer names, the
 * same number of weight arrays per layer, and the same shape for each
 * array. A structural mismatch is treated as a programming error and trips
 * an `assert` rather than returning false. All structural checks run before
 * any values are compared.
 *
 * @param wc1 first collection
 * @param wc2 second collection
 * @param tol tolerance passed to `NDArray.checkEqual` for every array pair
 * @return true if every pair of corresponding arrays is equal within `tol`
 */
def checkEqual(wc1: WeightCollection, wc2: WeightCollection, tol: Float): Boolean = {
  assert(wc1.layerNames == wc2.layerNames)
  val layerNames = wc1.layerNames
  // check that the WeightCollection objects have the same shape
  for (i <- 0 until wc1.numLayers) {
    val weights1 = wc1.allWeights(layerNames(i))
    val weights2 = wc2.allWeights(layerNames(i))
    assert(weights1.length == weights2.length)
    for (j <- 0 until weights1.length) {
      assert(weights1(j).shape.deep == weights2(j).shape.deep)
    }
  }
  // check that the weights are equal; forall short-circuits on the first
  // mismatch without the exception-based nonlocal `return` from a closure
  (0 until wc1.numLayers).forall { i =>
    val weights1 = wc1.allWeights(layerNames(i))
    val weights2 = wc2.allWeights(layerNames(i))
    (0 until weights1.length).forall { j =>
      NDArray.checkEqual(weights1(j), weights2(j), tol)
    }
  }
}
}
40 changes: 13 additions & 27 deletions src/test/scala/libs/CaffeNetSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,7 @@ class CaffeNetSpec extends FlatSpec {
val weightsBefore = net.getWeights()
val outputs = net.forward(inputs.iterator)
val weightsAfter = net.getWeights()

// check that the weights are unchanged
assert(weightsBefore.layerNames == weightsAfter.layerNames)
val layerNames = weightsBefore.layerNames
for (i <- 0 to weightsBefore.numLayers - 1) {
assert(weightsBefore.allWeights(layerNames(i)).length == weightsAfter.allWeights(layerNames(i)).length)
for (j <- 0 to weightsBefore.allWeights(layerNames(i)).length - 1) {
val weightBefore = weightsBefore.allWeights(layerNames(i))(j).toFlat
val weightAfter = weightsAfter.allWeights(layerNames(i))(j).toFlat
for (k <- 0 to weightBefore.length - 1) {
assert((weightBefore(k) - weightAfter(k)).abs <= 1e-6)
}
}
}
assert(WeightCollection.checkEqual(weightsBefore, weightsAfter, 1e-10F)) // weights should be equal
}

"Calling forwardBackward" should "leave weights unchanged" in {
Expand All @@ -78,19 +65,18 @@ class CaffeNetSpec extends FlatSpec {
val weightsBefore = net.getWeights()
net.forwardBackward(inputs.iterator)
val weightsAfter = net.getWeights()
assert(WeightCollection.checkEqual(weightsBefore, weightsAfter, 1e-10F)) // weights should be equal
}

// check that the weights are unchanged
assert(weightsBefore.layerNames == weightsAfter.layerNames)
val layerNames = weightsBefore.layerNames
for (i <- 0 to weightsBefore.numLayers - 1) {
assert(weightsBefore.allWeights(layerNames(i)).length == weightsAfter.allWeights(layerNames(i)).length)
for (j <- 0 to weightsBefore.allWeights(layerNames(i)).length - 1) {
val weightBefore = weightsBefore.allWeights(layerNames(i))(j).toFlat
val weightAfter = weightsAfter.allWeights(layerNames(i))(j).toFlat
for (k <- 0 to weightBefore.length - 1) {
assert((weightBefore(k) - weightAfter(k)).abs <= 1e-6)
}
}
}
// Round trip: weights saved from one net and loaded into a second,
// independently constructed net must make the two nets' weights equal.
"Saving and loading the weights" should "leave the weights unchanged" in {
  val netParam = new NetParameter()
  ReadProtoFromTextFileOrDie(sparkNetHome + "/models/cifar10/cifar10_quick_train_test.prototxt", netParam)
  val dataField = StructField("data", ArrayType(FloatType), false)
  val labelField = StructField("label", IntegerType)
  val schema = StructType(dataField :: labelField :: Nil)
  val weightsFile = sparkNetHome + "/temp/cifar10.caffemodel"

  val savedNet = CaffeNet(netParam, schema, new DefaultPreprocessor(schema))
  savedNet.saveWeightsToFile(weightsFile)

  val loadedNet = CaffeNet(netParam, schema, new DefaultPreprocessor(schema))
  // Before loading, the two nets are expected to differ
  // (presumably because each is initialized independently -- TODO confirm).
  val equalBeforeLoad = WeightCollection.checkEqual(savedNet.getWeights(), loadedNet.getWeights(), 1e-10F)
  assert(!equalBeforeLoad) // weights should not be equal

  loadedNet.copyTrainedLayersFrom(weightsFile)
  val equalAfterLoad = WeightCollection.checkEqual(savedNet.getWeights(), loadedNet.getWeights(), 1e-10F)
  assert(equalAfterLoad) // weights should be equal
}
}

0 comments on commit b811223

Please sign in to comment.