From 1bbb6f5200d9b63fa98d7dd17c1e5c8742dc698f Mon Sep 17 00:00:00 2001 From: Kim Feichtinger Date: Mon, 5 Mar 2018 21:03:28 +0000 Subject: [PATCH] Removing main from project * main class was there to test the NN functionality by solving XOR --- .gitignore | 3 +- README.md | 2 +- src/main/java/Main.java | 52 -------------------------------- src/main/java/NeuralNetwork.java | 16 +++++++--- 4 files changed, 15 insertions(+), 58 deletions(-) delete mode 100644 src/main/java/Main.java diff --git a/.gitignore b/.gitignore index ec376bb..6f68e91 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .idea -target \ No newline at end of file +target +out \ No newline at end of file diff --git a/README.md b/README.md index 88c049a..62fcd73 100644 --- a/README.md +++ b/README.md @@ -11,4 +11,4 @@ This is a very basic Java Neural Network library using EJML (Efficient Java Matr ## TODO -- add support for multy-layered Neural Networks \ No newline at end of file +- Add support for multy-layered Neural Net works \ No newline at end of file diff --git a/src/main/java/Main.java b/src/main/java/Main.java deleted file mode 100644 index 78b472c..0000000 --- a/src/main/java/Main.java +++ /dev/null @@ -1,52 +0,0 @@ -import java.util.Random; - -/** - * Created by KimFeichtinger on 04.03.18. - */ -public class Main { - - // Example program to test the basic neural network by letting it solve XOR - public static void main(String args[]){ - - Random r = new Random(); - - // Training Data - double[][] trainingData = { - {0, 0}, - {0, 1}, - {1, 0}, - {1, 1}, - }; - - double[][] trainingDataTargets = { - {0}, - {1}, - {1}, - {0}, - }; - - // Testing Data - double[][] testingData = { - {0, 0}, - {0, 1}, - {1, 0}, - {1, 1}, - }; - - NeuralNetwork nn = new NeuralNetwork(2,2, 1); - - // training - for (int i = 0; i < 50000; i++) { - // training in random order - int random = r.nextInt(4); - nn.train(trainingData[random], trainingDataTargets[random]); - } - - // testing the nn - for (int i = 0; i < testingData.length; i++) { - System.out.println("Guess for " + testingData[i][0] + ", " + testingData[i][1] + ": \n" + nn.feedForward(testingData[i])); - } - } - - -} diff --git a/src/main/java/NeuralNetwork.java b/src/main/java/NeuralNetwork.java index f3acbe3..394a409 100644 --- a/src/main/java/NeuralNetwork.java +++ b/src/main/java/NeuralNetwork.java @@ -1,7 +1,6 @@ import org.ejml.simple.SimpleMatrix; import utilities.Sigmoid; -import java.util.Arrays; import java.util.Random; /** @@ -9,10 +8,10 @@ */ public class NeuralNetwork { - private static final double LEARNING_RATE = 0.1; - Random r = new Random(); + private double learningRate = 0.1; + // "size" of the neural network private int inputNodes; private int hiddenNodes; @@ -93,7 +92,7 @@ private SimpleMatrix calculateLayer(SimpleMatrix weights, SimpleMatrix bias, Sim private SimpleMatrix calculateGradient(SimpleMatrix layer, SimpleMatrix error){ SimpleMatrix gradient = Sigmoid.applySigmoid(layer, true); gradient = gradient.elementMult(error); - return gradient.scale(LEARNING_RATE); + return gradient.scale(learningRate); } private SimpleMatrix calculateDeltas(SimpleMatrix gradient, SimpleMatrix layer){ @@ -104,4 +103,13 @@ private SimpleMatrix arrayToMatrix(double[] i){ double[][] input = {i}; return new SimpleMatrix(input).transpose(); } + + public double getLearningRate() { + return learningRate; + } + + public void setLearningRate(double learningRate) { + this.learningRate = learningRate; + } + }