Skip to content

Commit

Permalink
Added saving and reading nn-data to/from JSON-file
Browse files Browse the repository at this point in the history
  • Loading branch information
Kim Feichtinger authored and Kim Feichtinger committed Mar 7, 2018
1 parent dd26e18 commit 9e49497
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 30 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.idea
target
out
artifacts
artifacts
nn_data.json
29 changes: 18 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,39 @@ If you want to learn more about Neural Networks check out these YouTube-playlist

- Neural Network with variable amounts of inputs, hidden nodes and outputs
- Two layered (hidden + output)
- [Maven](https://maven.apache.org) Build-Management
- Save the weights and biases of a NN to a JSON-file
- Read a JSON-file with NN-data
- [EJML](https://www.ejml.org) used for Matrix math

## Examples (Usages)

- [XOR solved with Basic Neural Network Library](https://github.com/kim-marcel/xor_with_nn)

## Download

If you want to use this library you can [download](https://github.com/kim-marcel/basic_neural_network/releases/download/v0.1-alpha/basic_neural_network-v0.1-alpha.jar) v0.1-alpha here or check the release tab of this repository. There might be a newer version available.
- [Maven](https://maven.apache.org) Build-Management

## Code example

```java
// Neural Network with 2 inputs, 4 hidden nodes and 1 output
NeuralNetwork nn = new NeuralNetwork(2, 4, 1);

// Reads from a (previously generated) JSON-file the weights and biases of the NN
nn.readFromFile();

// Train the Neural Network with a training dataset
nn.train(trainingDataInputs, trainingDataTargets);

// Guess for the given testing data is returned as a 2D array (double[][])
nn.guess(testingData);

// Writes a JSON-file with the current "state" (weights and biases) of the NN
nn.writeToFile();
```
For a more detailed example check out [this](https://github.com/kim-marcel/xor_with_nn) repository.
A more detailed example cam be found below.

## Download

If you want to use this library you can download [v0.1-alpha](https://github.com/kim-marcel/basic_neural_network/releases/download/v0.1-alpha/basic_neural_network-v0.1-alpha.jar) here or check the release tab of this repository.

## Examples

- [XOR solved with Basic Neural Network Library](https://github.com/kim-marcel/xor_with_nn)

## Upcoming features

- Support for multiple layers
- Save and load the neural network to a file
19 changes: 19 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@
<artifactId>basic_neural_network</artifactId>
<version>v0.1-alpha</version>

<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
</plugins>
</build>

<dependencies>

<dependency>
Expand All @@ -16,6 +29,12 @@
<version>0.33</version>
</dependency>

<dependency>
<groupId>com.googlecode.json-simple</groupId>
<artifactId>json-simple</artifactId>
<version>1.1.1</version>
</dependency>

</dependencies>


Expand Down
73 changes: 55 additions & 18 deletions src/main/java/NeuralNetwork.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import org.ejml.simple.SimpleMatrix;
import org.json.simple.JSONObject;
import org.json.simple.parser.JSONParser;
import org.json.simple.parser.ParseException;
import utilities.MatrixConverter;
import utilities.Sigmoid;

import java.util.Arrays;
import java.io.*;
import java.util.Random;

/**
Expand All @@ -21,11 +25,11 @@ public class NeuralNetwork {
private SimpleMatrix weightsIH; // matrix with weights between input and hidden layer
private SimpleMatrix weightsHO; // matrix with weights between hidden and output layer

private SimpleMatrix biasH;
private SimpleMatrix biasO;
private SimpleMatrix biasH; // bias of the hidden layer
private SimpleMatrix biasO; // bias of the output layer

// Constructor
// generate a new neural network with 1 hidden layer with the given amount of nodes in the layers
// generate a new neural network with 1 hidden layer with the given amount of nodes in the individual layers
public NeuralNetwork(int inputNodes, int hiddenNodes, int outputNodes){
this.inputNodes = inputNodes;
this.hiddenNodes = hiddenNodes;
Expand All @@ -43,17 +47,17 @@ public NeuralNetwork(int inputNodes, int hiddenNodes, int outputNodes){
// guess method, input is a one column matrix with the input values
public double[][] guess(double[] i){
// transform array to matrix
SimpleMatrix inputs = arrayToMatrix(i);
SimpleMatrix inputs = MatrixConverter.arrayToMatrix(i);

SimpleMatrix hidden = calculateLayer(weightsIH, biasH, inputs);
SimpleMatrix output = calculateLayer(weightsHO, biasO, hidden);
return matrixToArray(output);
return MatrixConverter.matrixToArray(output);
}

public void train(double[] i, double[] t){
// transform 2d array to matrix
SimpleMatrix inputs = arrayToMatrix(i);
SimpleMatrix targets = arrayToMatrix(t);
SimpleMatrix inputs = MatrixConverter.arrayToMatrix(i);
SimpleMatrix targets = MatrixConverter.arrayToMatrix(t);

// calculate outputs of hidden and output layer for the given inputs
SimpleMatrix hidden = calculateLayer(weightsIH, biasH, inputs);
Expand Down Expand Up @@ -100,19 +104,52 @@ private SimpleMatrix calculateDeltas(SimpleMatrix gradient, SimpleMatrix layer){
return gradient.mult(layer.transpose());
}

private SimpleMatrix arrayToMatrix(double[] i){
double[][] input = {i};
return new SimpleMatrix(input).transpose();
public void writeToFile(){
JSONObject obj = new JSONObject();
obj.put("weightsIH", MatrixConverter.matrixToJSON(weightsIH));
obj.put("weightsHO", MatrixConverter.matrixToJSON(weightsHO));
obj.put("biasH", MatrixConverter.matrixToJSON(biasH));
obj.put("biasO", MatrixConverter.matrixToJSON(biasO));

try (FileWriter file = new FileWriter("nn_data.json")) {

file.write(obj.toJSONString());
file.flush();

} catch (IOException e) {
e.printStackTrace();
}

}

public double[][] matrixToArray(SimpleMatrix i){
double[][] result = new double[i.numRows()][i.numCols()];
for (int j = 0; j < result.length; j++) {
for (int k = 0; k < result[0].length; k++) {
result[j][k] = i.get(j, k);
}
public void readFromFile(){
JSONParser parser = new JSONParser();

try {

Object obj = parser.parse(new FileReader("nn_data.json"));

JSONObject nnData = (JSONObject) obj;

JSONObject weightsIHJSON = (JSONObject) nnData.get("weightsIH");
JSONObject weightsHOJSON = (JSONObject) nnData.get("weightsHO");

weightsIH = MatrixConverter.jsonToMatrix(weightsIHJSON);
weightsHO = MatrixConverter.jsonToMatrix(weightsHOJSON);

JSONObject biasHJSON = (JSONObject) nnData.get("biasH");
JSONObject biasOJSON = (JSONObject) nnData.get("biasO");

biasH = MatrixConverter.jsonToMatrix(biasHJSON);
biasO = MatrixConverter.jsonToMatrix(biasOJSON);

} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
} catch (ParseException e) {
e.printStackTrace();
}
return result;
}

public double getLearningRate() {
Expand Down
57 changes: 57 additions & 0 deletions src/main/java/utilities/MatrixConverter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package utilities;

import org.ejml.simple.SimpleMatrix;
import org.json.simple.JSONObject;

/**
* Created by KimFeichtinger on 07.03.18.
*/
public class MatrixConverter {

public static SimpleMatrix arrayToMatrix(double[] i){
double[][] input = {i};
return new SimpleMatrix(input).transpose();
}

public static double[][] matrixToArray(SimpleMatrix i){
double[][] result = new double[i.numRows()][i.numCols()];
for (int j = 0; j < result.length; j++) {
for (int k = 0; k < result[0].length; k++) {
result[j][k] = i.get(j, k);
}
}
return result;
}

public static SimpleMatrix jsonToMatrix(JSONObject i){
int rows = i.size();
int cols = ((JSONObject) i.get(Integer.toString(0))).size();

SimpleMatrix result = new SimpleMatrix(rows, cols);

for (int j = 0; j < i.size(); j++) {
JSONObject js = (JSONObject) i.get(Integer.toString(j));
for (int k = 0; k < js.size(); k++) {
double d = (double) js.get(Integer.toString(k));
result.set(j, k, d);
}
}

return result;
}

public static JSONObject matrixToJSON(SimpleMatrix i){
JSONObject result = new JSONObject();
JSONObject a = new JSONObject();

for (int j = 0; j < i.numRows(); j++) {
a.clear();
for (int k = 0; k < i.numCols(); k++) {
a.put(k, i.get(j, k));
}
result.put(j, new JSONObject(a));
}

return result;
}
}

0 comments on commit 9e49497

Please sign in to comment.