Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
* Simplified example (Main class)
* Neural network can now take arrays as inputs instead of Matrices
* modified readme
* outsource sigmoid to own class with static methods
  • Loading branch information
Kim Feichtinger authored and Kim Feichtinger committed Mar 5, 2018
1 parent c4c2347 commit a0b79ad
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 98 deletions.
15 changes: 14 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,14 @@
# Basic Neural Network
# Basic Neural Network

This is a very basic Java Neural Network library using EJML (Efficient Java Matrix Library) based on an example by Daniel Shiffman.

## Features

- Neural Network with variable amounts of Inputs, hidden nodes and outputs
- Only one hidden layer supported (yet)
- Maven Dependency manager
- ...

## TODO

- add support for multy-layered Neural Networks
100 changes: 34 additions & 66 deletions src/main/java/Main.java
Original file line number Diff line number Diff line change
@@ -1,83 +1,51 @@
import org.ejml.simple.SimpleMatrix;

import java.util.Random;

/**
* Created by KimFeichtinger on 04.03.18.
*/
public class Main {

// Example program to test the basic neural network by letting it solve XOR
public static void main(String args[]){

SimpleMatrix[] trainingData = new SimpleMatrix[4];
for (int i = 0; i < trainingData.length; i++) {
trainingData[i] = new SimpleMatrix(2,1);
}

trainingData[0].set(0,0,0);
trainingData[0].set(1,0,0);

trainingData[1].set(0,0,1);
trainingData[1].set(1,0,1);

trainingData[2].set(0,0,1);
trainingData[2].set(1,0,0);

trainingData[3].set(0,0,0);
trainingData[3].set(1,0,1);

SimpleMatrix[] trainingDataTargets = new SimpleMatrix[4];
for (int i = 0; i < trainingDataTargets.length; i++) {
trainingDataTargets[i] = new SimpleMatrix(1,1);
}

trainingDataTargets[0].set(0,0,0);

trainingDataTargets[1].set(0,0,0);

trainingDataTargets[2].set(0,0,1);

trainingDataTargets[3].set(0,0,1);

SimpleMatrix testData1 = new SimpleMatrix(2,1);
testData1.set(0,0,0);
testData1.set(1,0,0);

SimpleMatrix testData2 = new SimpleMatrix(2,1);
testData2.set(0,0,1);
testData2.set(1,0,1);

SimpleMatrix testData3 = new SimpleMatrix(2,1);
testData3.set(0,0,0);
testData3.set(1,0,1);

SimpleMatrix testData4 = new SimpleMatrix(2,1);
testData4.set(0,0,1);
testData4.set(1,0,0);

// // IMPORTANT: inputs has to be a one column matrix!
// SimpleMatrix inputs = new SimpleMatrix(2,1);
// inputs.set(0.24);
//
// // correct answers for the given input
// SimpleMatrix targets = new SimpleMatrix(1,1);
// targets.set(0.1);
Random r = new Random();

// Training Data
double[][] trainingData = {
{0, 0},
{0, 1},
{1, 0},
{1, 1},
};

double[][] trainingDataTargets = {
{0},
{1},
{1},
{0},
};

// Testing Data
double[][] testingData = {
{0, 0},
{0, 1},
{1, 0},
{1, 1},
};

NeuralNetwork nn = new NeuralNetwork(2,2, 1);

//System.out.println(nn.feedForward(inputs));

for (int i = 0; i < 100000; i++) {
for (int j = 0; j < trainingData.length; j++) {
nn.train(trainingData[j], trainingDataTargets[j]);
}
// training
for (int i = 0; i < 50000; i++) {
// training in random order
int random = r.nextInt(4);
nn.train(trainingData[random], trainingDataTargets[random]);
}

System.out.println("0,0: " + nn.feedForward(testData1));
System.out.println("1,1: " + nn.feedForward(testData2));
System.out.println("0,1: " + nn.feedForward(testData3));
System.out.println("1,0: " + nn.feedForward(testData4));
// nn.train(inputs, targets);
// testing the nn
for (int i = 0; i < testingData.length; i++) {
System.out.println("Guess for " + testingData[i][0] + ", " + testingData[i][1] + ": \n" + nn.feedForward(testingData[i]));
}
}


Expand Down
48 changes: 17 additions & 31 deletions src/main/java/NeuralNetwork.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import org.ejml.simple.SimpleMatrix;
import utilities.Sigmoid;

import java.util.Arrays;
import java.util.Random;

/**
Expand All @@ -11,7 +13,7 @@ public class NeuralNetwork {

Random r = new Random();

// variables which store the "size" of the neural network
// "size" of the neural network
private int inputNodes;
private int hiddenNodes;
private int outputNodes;
Expand Down Expand Up @@ -39,13 +41,20 @@ public NeuralNetwork(int inputNodes, int hiddenNodes, int outputNodes){
}

// feedForward method, input is a one column matrix with the input values
public SimpleMatrix feedForward(SimpleMatrix inputs){
public SimpleMatrix feedForward(double[] i){
// transform array to matrix
SimpleMatrix inputs = arrayToMatrix(i);

SimpleMatrix hidden = calculateLayer(weightsIH, biasH, inputs);
SimpleMatrix output = calculateLayer(weightsHO, biasO, hidden);
return output;
}

public void train(SimpleMatrix inputs, SimpleMatrix targets){
public void train(double[] i, double[] t){
// transform 2d array to matrix
SimpleMatrix inputs = arrayToMatrix(i);
SimpleMatrix targets = arrayToMatrix(t);

// calculate outputs of hidden and output layer for the given inputs
SimpleMatrix hidden = calculateLayer(weightsIH, biasH, inputs);
SimpleMatrix outputs = calculateLayer(weightsHO, biasO, hidden);
Expand All @@ -70,21 +79,19 @@ public void train(SimpleMatrix inputs, SimpleMatrix targets){
biasH = biasH.plus(hiddenGradient);
}

// ***** Helping methods: *****

// generic function to calculate one layer
private SimpleMatrix calculateLayer(SimpleMatrix weights, SimpleMatrix bias, SimpleMatrix input){
// calculate outputs of layer
SimpleMatrix result = weights.mult(input);
// add bias to outputs
result = result.plus(bias);
// apply activation function and return result
result = applySigmoid(result, false);
result = Sigmoid.applySigmoid(result, false);
return result;
}

private SimpleMatrix calculateGradient(SimpleMatrix layer, SimpleMatrix error){
SimpleMatrix gradient = applySigmoid(layer, true);
SimpleMatrix gradient = Sigmoid.applySigmoid(layer, true);
gradient = gradient.elementMult(error);
return gradient.scale(LEARNING_RATE);
}
Expand All @@ -93,29 +100,8 @@ private SimpleMatrix calculateDeltas(SimpleMatrix gradient, SimpleMatrix layer){
return gradient.mult(layer.transpose());
}

private SimpleMatrix applySigmoid(SimpleMatrix input, boolean derivative){
SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());
for (int i = 0; i < input.numRows(); i++) {
for (int j = 0; j < input.numCols(); j++) {
double value = input.get(i, j);
// apply dsigmoid if derivative = true, otherwise usual sigmoid
if(derivative){
output.set(i, j, dsigmoid(value));
}else {
output.set(i, j, sigmoid(value));
}
}
}

return output;
}

private double sigmoid(double input){
return 1 / (1 + Math.exp(-input));
}

// derivative of sigmoid (not real derivative because sigmoid function has already been applied to the input)
private double dsigmoid(double input){
return input * (1 - input);
private SimpleMatrix arrayToMatrix(double[] i){
double[][] input = {i};
return new SimpleMatrix(input).transpose();
}
}
36 changes: 36 additions & 0 deletions src/main/java/utilities/Sigmoid.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package utilities;

import org.ejml.simple.SimpleMatrix;

/**
* Created by KimFeichtinger on 05.03.18.
*/
public class Sigmoid {

public static SimpleMatrix applySigmoid(SimpleMatrix input, boolean derivative){
SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());
for (int i = 0; i < input.numRows(); i++) {
for (int j = 0; j < input.numCols(); j++) {
double value = input.get(i, j);
// apply dsigmoid if derivative = true, otherwise usual Sigmoid
if(derivative){
output.set(i, j, dsigmoid(value));
}else {
output.set(i, j, sigmoid(value));
}
}
}

return output;
}

private static double sigmoid(double input){
return 1 / (1 + Math.exp(-input));
}

// derivative of Sigmoid (not real derivative because Sigmoid function has already been applied to the input)
private static double dsigmoid(double input){
return input * (1 - input);
}

}

0 comments on commit a0b79ad

Please sign in to comment.