Skip to content

Instantly share code, notes, and snippets.

@funrep
Created June 20, 2017 17:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save funrep/c2dfb3225c0eb5956aed722f776f2d24 to your computer and use it in GitHub Desktop.
Save funrep/c2dfb3225c0eb5956aed722f776f2d24 to your computer and use it in GitHub Desktop.
Neural net - not working
import java.util.ArrayList;
/**
 * One fully connected layer: an ordered list of {@link Neuron}s that all read the
 * same number of inputs.
 */
public class Layer {
    /** Neurons in this layer, kept in creation order. */
    private final ArrayList<Neuron> nodes = new ArrayList<>();

    /**
     * Creates {@code nodeCount} new neurons, each built with {@code inputCount} weights.
     * A non-positive {@code nodeCount} yields an empty layer.
     *
     * @param nodeCount  number of neurons in this layer
     * @param inputCount number of inputs (weights) per neuron
     */
    public Layer(int nodeCount, int inputCount) {
        int remaining = nodeCount;
        while (remaining > 0) {
            nodes.add(new Neuron(inputCount));
            remaining--;
        }
    }

    /**
     * Returns the layer's neurons.
     *
     * @return the live, mutable backing list (not a copy)
     */
    public ArrayList<Neuron> getNodes() {
        return nodes;
    }
}
import java.util.ArrayList;
/**
 * Demo driver: trains a 2-2-1 {@link Network} on the XOR truth table, then prints
 * each input, the network's prediction, and the expected output.
 */
public class Main {
    /**
     * Step size for weight updates. The original 0.01 combined with 1000 epochs was far
     * too small a training budget for the network to move away from ~0.5 outputs on XOR.
     */
    private static final double LEARNING_RATE = 0.5;
    /** Number of full passes over the four XOR samples. */
    private static final int MAX_EPOCH = 20000;
    /** Print the summed training error once every this many epochs. */
    private static final int REPORT_EVERY = 2000;

    public static void main(String[] args) {
        // Test by learning XOR boolean operator
        double[][] sampleIn = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
        double[][] sampleOut = { { 0 }, { 1 }, { 1 }, { 0 } };
        ArrayList<ArrayList<Double>> trainingInp = conv(sampleIn);
        ArrayList<ArrayList<Double>> trainingOut = conv(sampleOut);
        // Two input nodes, two nodes in hidden layer, 1 output node
        int[] layers = { 2, 2, 1 };
        Network net = new Network(layers, LEARNING_RATE);
        train(MAX_EPOCH, net, trainingInp, trainingOut, REPORT_EVERY);
        for (int i = 0; i < trainingInp.size(); i++) {
            ArrayList<Double> res = net.runNetwork(trainingInp.get(i));
            printArr(trainingInp.get(i));
            System.out.print("net: ");
            printArr(res);
            System.out.print("sample: ");
            printArr(trainingOut.get(i));
            System.out.println();
        }
    }

    /**
     * Trains {@code net} for {@code maxEpoch} epochs, printing the summed error after
     * every epoch. This is kept for backward compatibility; it delegates with
     * {@code reportEvery = 1}.
     */
    public static void train(int maxEpoch, Network net,
            ArrayList<ArrayList<Double>> trainingInp,
            ArrayList<ArrayList<Double>> trainingOut) {
        train(maxEpoch, net, trainingInp, trainingOut, 1);
    }

    /**
     * Runs online backpropagation over every (input, target) pair for {@code maxEpoch}
     * epochs.
     *
     * @param maxEpoch    number of passes over the training set
     * @param net         network to train, updated in place
     * @param trainingInp input vectors, one per sample
     * @param trainingOut target vectors, index-aligned with {@code trainingInp}
     * @param reportEvery print the epoch's summed error when {@code (epoch + 1)} is a
     *                    multiple of this value
     * @throws IllegalArgumentException if {@code reportEvery <= 0} or the two training
     *                                  lists differ in length
     */
    public static void train(int maxEpoch, Network net,
            ArrayList<ArrayList<Double>> trainingInp,
            ArrayList<ArrayList<Double>> trainingOut,
            int reportEvery) {
        if (reportEvery <= 0) {
            throw new IllegalArgumentException("reportEvery must be positive: " + reportEvery);
        }
        if (trainingInp.size() != trainingOut.size()) {
            throw new IllegalArgumentException("inputs/targets size mismatch: "
                    + trainingInp.size() + " vs " + trainingOut.size());
        }
        for (int epoch = 0; epoch < maxEpoch; epoch++) {
            double error = 0;
            for (int i = 0; i < trainingInp.size(); i++) {
                ArrayList<Double> targets = trainingOut.get(i);
                net.backprop(trainingInp.get(i), targets);
                // Error of the forward pass made inside backprop, i.e. before this
                // sample's weight update was applied.
                error += totalError(targets, net.lastOutput());
            }
            if ((epoch + 1) % reportEvery == 0) {
                System.out.println("Error: " + error);
            }
        }
    }

    /**
     * Computes the sum over i of {@code 0.5 * (target[i] - output[i])^2}.
     * Iterates over {@code target}'s length.
     */
    public static double totalError(ArrayList<Double> target, ArrayList<Double> output) {
        double sum = 0;
        for (int i = 0; i < target.size(); i++) {
            double diff = target.get(i) - output.get(i);
            sum += 0.5 * diff * diff;
        }
        return sum;
    }

    /** Converts a 2-D array into a list of row lists, preserving order. */
    public static ArrayList<ArrayList<Double>> conv(double[][] arr) {
        ArrayList<ArrayList<Double>> rows = new ArrayList<>(arr.length);
        for (double[] row : arr) {
            ArrayList<Double> inner = new ArrayList<>(row.length);
            for (double v : row) {
                inner.add(v);
            }
            rows.add(inner);
        }
        return rows;
    }

    /** Prints each element followed by a space, without a trailing newline. */
    public static void printArr(ArrayList<Double> arrayList) {
        for (Double n : arrayList) {
            System.out.print(n + " ");
        }
    }
}
import java.util.ArrayList;
/**
 * A fully connected feed-forward network of sigmoid {@link Neuron}s trained by online
 * (per-sample) backpropagation.
 *
 * <p>The input layer is implicit: {@code layerCounts[0]} only sets how many weights each
 * neuron of the first stored layer has.
 */
public class Network {
    /** Weight-carrying layers, from the first hidden layer to the output layer. */
    private ArrayList<Layer> layers;
    /** Step size applied to every weight and bias update. */
    private double learningRate;

    /**
     * Builds a network with randomly initialised neurons.
     *
     * @param layerCounts  node count per layer, input layer first
     * @param learningRate step size used by {@link #backprop}
     * @throws IllegalArgumentException if fewer than two layer sizes are given
     */
    public Network(int[] layerCounts, double learningRate) {
        if (layerCounts.length < 2) {
            throw new IllegalArgumentException(
                    "layerCounts needs at least an input and an output size, got "
                            + layerCounts.length);
        }
        layers = new ArrayList<>();
        for (int i = 1; i < layerCounts.length; i++) {
            layers.add(new Layer(layerCounts[i], layerCounts[i - 1]));
        }
        this.learningRate = learningRate;
    }

    /** Performs one training step: forward pass, delta computation, weight update. */
    public void backprop(ArrayList<Double> inputs, ArrayList<Double> targets) {
        feedForward(inputs);
        calcErrors(targets);
        backpropErrors(inputs);
    }

    /** Runs a forward pass and returns the output-layer activations. */
    public ArrayList<Double> runNetwork(ArrayList<Double> input) {
        feedForward(input);
        return lastOutput();
    }

    /** Returns the output-layer activations from the most recent forward pass. */
    public ArrayList<Double> lastOutput() {
        return outputsOf(layers.get(layers.size() - 1));
    }

    /** Propagates {@code inputs} through every layer, storing each neuron's output. */
    public void feedForward(ArrayList<Double> inputs) {
        ArrayList<Double> signal = inputs;
        for (Layer layer : layers) {
            ArrayList<Double> next = new ArrayList<>();
            for (Neuron n : layer.getNodes()) {
                next.add(n.calcOut(signal));
            }
            signal = next;
        }
    }

    /**
     * Computes and stores each neuron's delta for the last forward pass.
     *
     * <p>Output neuron: {@code (t - o) * o * (1 - o)}. Hidden neuron j:
     * {@code o_j * (1 - o_j) * sum_k(w_kj * delta_k)} over neurons k of the NEXT layer.
     * The original summed the neuron's own weights times its own stale error, so hidden
     * layers never received a real gradient.
     *
     * @throws IllegalArgumentException if {@code targets} does not match the output size
     */
    public void calcErrors(ArrayList<Double> targets) {
        ArrayList<Neuron> outputNodes = layers.get(layers.size() - 1).getNodes();
        if (targets.size() != outputNodes.size()) {
            throw new IllegalArgumentException("expected " + outputNodes.size()
                    + " targets but got " + targets.size());
        }
        for (int i = 0; i < targets.size(); i++) {
            Neuron n = outputNodes.get(i);
            double o = n.getOutput();
            n.setError((targets.get(i) - o) * o * (1 - o));
        }
        // Walk backwards so the downstream layer's deltas are already known.
        for (int i = layers.size() - 2; i >= 0; i--) {
            ArrayList<Neuron> current = layers.get(i).getNodes();
            ArrayList<Neuron> downstream = layers.get(i + 1).getNodes();
            for (int j = 0; j < current.size(); j++) {
                double sum = 0;
                for (Neuron d : downstream) {
                    // d's weight j connects it to neuron j of the current layer.
                    sum += d.getWeight(j) * d.getError();
                }
                Neuron n = current.get(j);
                double o = n.getOutput();
                n.setError(o * (1 - o) * sum);
            }
        }
    }

    /**
     * Applies {@code w += learningRate * delta * input} and {@code b += learningRate * delta}.
     * All deltas were computed by {@link #calcErrors} before any weight changes, so
     * updating from the output layer backwards does not affect the gradient.
     */
    public void backpropErrors(ArrayList<Double> inputs) {
        for (int i = layers.size() - 1; i >= 0; i--) {
            ArrayList<Double> layerInputs = (i == 0) ? inputs : outputsOf(layers.get(i - 1));
            for (Neuron n : layers.get(i).getNodes()) {
                double step = learningRate * n.getError();
                n.setBias(n.getBias() + step);
                for (int k = 0; k < n.getWeights().size(); k++) {
                    n.setWeight(k, n.getWeight(k) + step * layerInputs.get(k));
                }
            }
        }
    }

    /** Collects the most recently computed output of every neuron in {@code layer}. */
    private static ArrayList<Double> outputsOf(Layer layer) {
        ArrayList<Double> result = new ArrayList<>();
        for (Neuron n : layer.getNodes()) {
            result.add(n.getOutput());
        }
        return result;
    }
}
import java.util.ArrayList;
import java.util.Random;
public class Neuron {
private ArrayList<Double> weights;
private double bias;
private double output;
private double error;
public double getOutput() {
return output;
}
public void setOutput(double output) {
this.output = output;
}
public double getError() {
return error;
}
public void setError(double error) {
this.error = error;
}
public double getBias() {
return bias;
}
public void setBias(double bias) {
this.bias = bias;
}
public ArrayList<Double> getWeights() {
return this.weights;
}
public double getWeight(int i) {
return this.weights.get(i);
}
public void setWeight(int i, double w) {
this.weights.set(i, w);
}
public Neuron(int weightCount) {
Random rnd = new Random();
weights = new ArrayList<>();
for (int i = 0; i < weightCount; i++) {
weights.add(rnd.nextDouble() * 2 - 1);
}
bias = rnd.nextDouble() * 2 - 1;
output = 0.0;
error = 0.0;
}
public double calcOut(ArrayList<Double> inputs) {
double out = bias;
for (int i = 0; i < weights.size(); i++) {
out += inputs.get(i) * weights.get(i);
}
return this.output = sigmoid(out);
}
public double sigmoid(double n) {
return 1 / (1 + Math.exp(-n));
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment