diff --git a/lib/matrix.js b/lib/matrix.js index 34b2d97..dd396eb 100644 --- a/lib/matrix.js +++ b/lib/matrix.js @@ -9,13 +9,7 @@ class Matrix { } copy() { - let m = new Matrix(this.rows, this.cols); - for (let i = 0; i < this.rows; i++) { - for (let j = 0; j < this.cols; j++) { - m.data[i][j] = this.data[i][j]; - } - } - return m; + return Matrix.deserialize(this.serialize()); } static fromArray(arr) { diff --git a/lib/matrix.test.js b/lib/matrix.test.js index 4933895..49fc653 100644 --- a/lib/matrix.test.js +++ b/lib/matrix.test.js @@ -363,6 +363,7 @@ test('static map with row and column params', () => { ] }); }); + test('matrix (de)serialization', () => { let m = new Matrix(5, 5); m.randomize(); @@ -375,3 +376,16 @@ test('matrix (de)serialization', () => { data: m.data }); }); + +test('matrix copy', () => { + let m = new Matrix(5, 5); + m.randomize(); + + let n = m.copy(); + + expect(n).toEqual({ + rows: m.rows, + cols: m.cols, + data: m.data + }); +}); diff --git a/lib/nn.js b/lib/nn.js index 0260280..ba9124c 100644 --- a/lib/nn.js +++ b/lib/nn.js @@ -19,39 +19,22 @@ let tanh = new ActivationFunction( class NeuralNetwork { - // TODO: document what a, b, c are - constructor(a, b, c) { - if (a instanceof NeuralNetwork) { - this.input_nodes = a.input_nodes; - this.hidden_nodes = a.hidden_nodes; - this.output_nodes = a.output_nodes; - - this.weights_ih = a.weights_ih.copy(); - this.weights_ho = a.weights_ho.copy(); - - this.bias_h = a.bias_h.copy(); - this.bias_o = a.bias_o.copy(); - } else { - this.input_nodes = a; - this.hidden_nodes = b; - this.output_nodes = c; - - this.weights_ih = new Matrix(this.hidden_nodes, this.input_nodes); - this.weights_ho = new Matrix(this.output_nodes, this.hidden_nodes); - this.weights_ih.randomize(); - this.weights_ho.randomize(); - - this.bias_h = new Matrix(this.hidden_nodes, 1); - this.bias_o = new Matrix(this.output_nodes, 1); - this.bias_h.randomize(); - this.bias_o.randomize(); - } - - // TODO: copy these as well + constructor(input_nodes, hidden_nodes, output_nodes) { + this.input_nodes = input_nodes; + this.hidden_nodes = hidden_nodes; + this.output_nodes = output_nodes; + + this.weights_ih = new Matrix(this.hidden_nodes, this.input_nodes); + this.weights_ho = new Matrix(this.output_nodes, this.hidden_nodes); + this.weights_ih.randomize(); + this.weights_ho.randomize(); + + this.bias_h = new Matrix(this.hidden_nodes, 1); + this.bias_o = new Matrix(this.output_nodes, 1); + this.bias_h.randomize(); + this.bias_o.randomize(); this.setLearningRate(); this.setActivationFunction(); - - } predict(input_array) { @@ -158,7 +141,7 @@ class NeuralNetwork { // Adding function for neuro-evolution copy() { - return new NeuralNetwork(this); + return NeuralNetwork.deserialize(this.serialize()); } mutate(rate) {