diff --git a/README.md b/README.md index 15fdea7..88c049a 100644 --- a/README.md +++ b/README.md @@ -1 +1,14 @@ -# Basic Neural Network \ No newline at end of file +# Basic Neural Network + +This is a very basic Java Neural Network library using EJML (Efficient Java Matrix Library), based on an example by Daniel Shiffman. + +## Features + +- Neural Network with a variable number of inputs, hidden nodes and outputs +- Only one hidden layer supported (so far) +- Maven Dependency manager +- ... + +## TODO + +- add support for multi-layered Neural Networks \ No newline at end of file diff --git a/src/main/java/Main.java b/src/main/java/Main.java index f047216..78b472c 100644 --- a/src/main/java/Main.java +++ b/src/main/java/Main.java @@ -1,5 +1,3 @@ -import org.ejml.simple.SimpleMatrix; - import java.util.Random; /** @@ -7,77 +5,47 @@ */ public class Main { + // Example program to test the basic neural network by letting it solve XOR public static void main(String args[]){ - SimpleMatrix[] trainingData = new SimpleMatrix[4]; - for (int i = 0; i < trainingData.length; i++) { - trainingData[i] = new SimpleMatrix(2,1); - } - - trainingData[0].set(0,0,0); - trainingData[0].set(1,0,0); - - trainingData[1].set(0,0,1); - trainingData[1].set(1,0,1); - - trainingData[2].set(0,0,1); - trainingData[2].set(1,0,0); - - trainingData[3].set(0,0,0); - trainingData[3].set(1,0,1); - - SimpleMatrix[] trainingDataTargets = new SimpleMatrix[4]; - for (int i = 0; i < trainingDataTargets.length; i++) { - trainingDataTargets[i] = new SimpleMatrix(1,1); - } - - trainingDataTargets[0].set(0,0,0); - - trainingDataTargets[1].set(0,0,0); - - trainingDataTargets[2].set(0,0,1); - - trainingDataTargets[3].set(0,0,1); - - SimpleMatrix testData1 = new SimpleMatrix(2,1); - testData1.set(0,0,0); - testData1.set(1,0,0); - - SimpleMatrix testData2 = new SimpleMatrix(2,1); - testData2.set(0,0,1); - testData2.set(1,0,1); - - SimpleMatrix testData3 = new SimpleMatrix(2,1); - testData3.set(0,0,0); - 
testData3.set(1,0,1); - - SimpleMatrix testData4 = new SimpleMatrix(2,1); - testData4.set(0,0,1); - testData4.set(1,0,0); - -// // IMPORTANT: inputs has to be a one column matrix! -// SimpleMatrix inputs = new SimpleMatrix(2,1); -// inputs.set(0.24); -// -// // correct answers for the given input -// SimpleMatrix targets = new SimpleMatrix(1,1); -// targets.set(0.1); + Random r = new Random(); + + // Training Data + double[][] trainingData = { + {0, 0}, + {0, 1}, + {1, 0}, + {1, 1}, + }; + + double[][] trainingDataTargets = { + {0}, + {1}, + {1}, + {0}, + }; + + // Testing Data + double[][] testingData = { + {0, 0}, + {0, 1}, + {1, 0}, + {1, 1}, + }; NeuralNetwork nn = new NeuralNetwork(2,2, 1); - //System.out.println(nn.feedForward(inputs)); - - for (int i = 0; i < 100000; i++) { - for (int j = 0; j < trainingData.length; j++) { - nn.train(trainingData[j], trainingDataTargets[j]); - } + // training + for (int i = 0; i < 50000; i++) { + // training in random order + int random = r.nextInt(4); + nn.train(trainingData[random], trainingDataTargets[random]); } - System.out.println("0,0: " + nn.feedForward(testData1)); - System.out.println("1,1: " + nn.feedForward(testData2)); - System.out.println("0,1: " + nn.feedForward(testData3)); - System.out.println("1,0: " + nn.feedForward(testData4)); -// nn.train(inputs, targets); + // testing the nn + for (int i = 0; i < testingData.length; i++) { + System.out.println("Guess for " + testingData[i][0] + ", " + testingData[i][1] + ": \n" + nn.feedForward(testingData[i])); + } } diff --git a/src/main/java/NeuralNetwork.java b/src/main/java/NeuralNetwork.java index 0491b27..f3acbe3 100644 --- a/src/main/java/NeuralNetwork.java +++ b/src/main/java/NeuralNetwork.java @@ -1,5 +1,7 @@ import org.ejml.simple.SimpleMatrix; +import utilities.Sigmoid; +import java.util.Arrays; import java.util.Random; /** @@ -11,7 +13,7 @@ public class NeuralNetwork { Random r = new Random(); - // variables which store the "size" of the neural network + 
// "size" of the neural network private int inputNodes; private int hiddenNodes; private int outputNodes; @@ -39,13 +41,20 @@ public NeuralNetwork(int inputNodes, int hiddenNodes, int outputNodes){ } // feedForward method, input is a one column matrix with the input values - public SimpleMatrix feedForward(SimpleMatrix inputs){ + public SimpleMatrix feedForward(double[] i){ + // transform array to matrix + SimpleMatrix inputs = arrayToMatrix(i); + SimpleMatrix hidden = calculateLayer(weightsIH, biasH, inputs); SimpleMatrix output = calculateLayer(weightsHO, biasO, hidden); return output; } - public void train(SimpleMatrix inputs, SimpleMatrix targets){ + public void train(double[] i, double[] t){ + // transform 2d array to matrix + SimpleMatrix inputs = arrayToMatrix(i); + SimpleMatrix targets = arrayToMatrix(t); + // calculate outputs of hidden and output layer for the given inputs SimpleMatrix hidden = calculateLayer(weightsIH, biasH, inputs); SimpleMatrix outputs = calculateLayer(weightsHO, biasO, hidden); @@ -70,8 +79,6 @@ public void train(SimpleMatrix inputs, SimpleMatrix targets){ biasH = biasH.plus(hiddenGradient); } - // ***** Helping methods: ***** - // generic function to calculate one layer private SimpleMatrix calculateLayer(SimpleMatrix weights, SimpleMatrix bias, SimpleMatrix input){ // calculate outputs of layer @@ -79,12 +86,12 @@ private SimpleMatrix calculateLayer(SimpleMatrix weights, SimpleMatrix bias, Sim // add bias to outputs result = result.plus(bias); // apply activation function and return result - result = applySigmoid(result, false); + result = Sigmoid.applySigmoid(result, false); return result; } private SimpleMatrix calculateGradient(SimpleMatrix layer, SimpleMatrix error){ - SimpleMatrix gradient = applySigmoid(layer, true); + SimpleMatrix gradient = Sigmoid.applySigmoid(layer, true); gradient = gradient.elementMult(error); return gradient.scale(LEARNING_RATE); } @@ -93,29 +100,8 @@ private SimpleMatrix calculateDeltas(SimpleMatrix 
gradient, SimpleMatrix layer){ return gradient.mult(layer.transpose()); } - private SimpleMatrix applySigmoid(SimpleMatrix input, boolean derivative){ - SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols()); - for (int i = 0; i < input.numRows(); i++) { - for (int j = 0; j < input.numCols(); j++) { - double value = input.get(i, j); - // apply dsigmoid if derivative = true, otherwise usual sigmoid - if(derivative){ - output.set(i, j, dsigmoid(value)); - }else { - output.set(i, j, sigmoid(value)); - } - } - } - - return output; - } - - private double sigmoid(double input){ - return 1 / (1 + Math.exp(-input)); - } - - // derivative of sigmoid (not real derivative because sigmoid function has already been applied to the input) - private double dsigmoid(double input){ - return input * (1 - input); + private SimpleMatrix arrayToMatrix(double[] i){ + double[][] input = {i}; + return new SimpleMatrix(input).transpose(); } } diff --git a/src/main/java/utilities/Sigmoid.java b/src/main/java/utilities/Sigmoid.java new file mode 100644 index 0000000..2e8da65 --- /dev/null +++ b/src/main/java/utilities/Sigmoid.java @@ -0,0 +1,36 @@ +package utilities; + +import org.ejml.simple.SimpleMatrix; + +/** + * Created by KimFeichtinger on 05.03.18. 
+ */ +public class Sigmoid { + + public static SimpleMatrix applySigmoid(SimpleMatrix input, boolean derivative){ + SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols()); + for (int i = 0; i < input.numRows(); i++) { + for (int j = 0; j < input.numCols(); j++) { + double value = input.get(i, j); + // apply dsigmoid if derivative = true, otherwise usual Sigmoid + if(derivative){ + output.set(i, j, dsigmoid(value)); + }else { + output.set(i, j, sigmoid(value)); + } + } + } + + return output; + } + + private static double sigmoid(double input){ + return 1 / (1 + Math.exp(-input)); + } + + // derivative of Sigmoid (not real derivative because Sigmoid function has already been applied to the input) + private static double dsigmoid(double input){ + return input * (1 - input); + } + +}