Last active
November 24, 2022 19:58
-
-
Save duongddinh/47e5d6e7b9630e1183f1565a4a86e648 to your computer and use it in GitHub Desktop.
Simple neural network example written in Java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.Random; | |
/**
 * A minimal single-layer neural network: one sigmoid neuron with three input weights and no
 * bias term, trained by full-batch gradient descent on a four-sample data set whose target is
 * the first input column.
 *
 * <p>All state is static: {@link #weights} is the 3x1 weight matrix and {@link #output} holds
 * the predictions from the most recent {@link #train} or {@link #application} call, one row per
 * input sample.
 */
public class Neural {
    /** Weight matrix: one row per input feature, one column per output. */
    static double[][] weights = new double[3][1];

    /** Most recent predictions; resized to match the number of samples last evaluated. */
    static double[][] output = new double[4][1];

    /** Number of gradient-descent steps taken by {@link #main}. */
    private static final int TRAINING_ITERATIONS = 1000;

    /** Trains on the fixed data set, then prints the predictions for four unseen inputs. */
    public static void main(String[] args) {
        double[][] input = {{0, 0, 1}, {1, 1, 1}, {1, 0, 1}, {0, 1, 1}};
        double[][] trueoutput = {{0}, {1}, {1}, {0}};
        // Unseen samples evaluated after training; these are inputs, not expected outputs.
        double[][] testinput = {{0, 1, 0}, {1, 0, 0}, {0, 0, 0}, {1, 1, 0}};

        GenWeights();
        for (int i = 0; i < TRAINING_ITERATIONS; i++) {
            train(input, trueoutput);
        }
        application(testinput);

        for (double[] row : output) {
            for (double value : row) {
                System.out.print(value + " ");
            }
            System.out.println();
        }
    }

    /**
     * Runs the network on {@code input} and stores the predictions in {@link #output}.
     *
     * @param input one sample per row; every row must have {@code weights.length} columns
     * @throws IllegalArgumentException if a row's length does not match the weight matrix
     */
    static void application(double[][] input) {
        output = forward(input);
    }

    /**
     * Performs one full-batch gradient-descent step on {@link #weights}.
     *
     * <p>The update is {@code weights += input^T * ((trueoutput - predicted) * predicted * (1 -
     * predicted))}, i.e. an implicit learning rate of 1. After the call, {@link #output} holds the
     * predictions computed with the weights as they were <em>before</em> this update.
     *
     * @param input one training sample per row; every row must have {@code weights.length} columns
     * @param trueoutput expected outputs, same shape as the predictions (rows x weights[0].length)
     * @throws IllegalArgumentException if the shapes of {@code input} or {@code trueoutput} are
     *     inconsistent with the weight matrix
     */
    static void train(double[][] input, double[][] trueoutput) {
        double[][] predicted = forward(input);
        output = predicted;
        if (predicted.length == 0) {
            return; // no samples: the gradient is zero, weights stay unchanged
        }

        int cols = weights[0].length;
        if (trueoutput.length != predicted.length) {
            throw new IllegalArgumentException("train: expected " + predicted.length
                    + " target rows but got " + trueoutput.length);
        }
        double[][] errChange = new double[predicted.length][cols];
        for (int i = 0; i < predicted.length; i++) {
            if (trueoutput[i].length != cols) {
                throw new IllegalArgumentException("train: target row " + i + " has "
                        + trueoutput[i].length + " columns, expected " + cols);
            }
            for (int j = 0; j < cols; j++) {
                double error = trueoutput[i][j] - predicted[i][j];
                // deriva_sigmoid takes the sigmoid *output*, which is what predicted holds.
                errChange[i][j] = error * deriva_sigmoid(predicted[i][j]);
            }
        }

        double[][] adjustment = dot(T(input), errChange);
        for (int i = 0; i < weights.length; i++) {
            for (int j = 0; j < weights[i].length; j++) {
                weights[i][j] += adjustment[i][j];
            }
        }
    }

    /**
     * Fills {@link #weights} with pseudo-random values in [0, 1). A fixed seed (8) is used, so
     * every run starts from the same weights.
     */
    static void GenWeights() {
        Random rand = new Random(8);
        for (double[] row : weights) {
            for (int j = 0; j < row.length; j++) {
                row[j] = rand.nextDouble();
            }
        }
    }

    /** Logistic sigmoid {@code 1 / (1 + e^-d)}. */
    static double sigmoid(double d) {
        // Math.exp is more accurate than Math.pow(Math.E, x): Math.E is itself a rounded value.
        return 1 / (1 + Math.exp(-d));
    }

    /**
     * Derivative of the sigmoid expressed in terms of its output.
     *
     * @param x a sigmoid output value (not the pre-activation input)
     * @return {@code x * (1 - x)}
     */
    static double deriva_sigmoid(double x) {
        return x * (1 - x);
    }

    /**
     * Returns the transpose of a rectangular matrix.
     *
     * @param a matrix with at least... any number of rows; all rows must have equal length
     * @return a new {@code a[0].length x a.length} matrix, or an empty matrix if {@code a} is empty
     * @throws IllegalArgumentException if {@code a} is ragged
     */
    static double[][] T(double[][] a) {
        if (a.length == 0) {
            return new double[0][0];
        }
        int rows = a.length;
        int cols = a[0].length;
        double[][] c = new double[cols][rows];
        for (int i = 0; i < rows; i++) {
            if (a[i].length != cols) {
                throw new IllegalArgumentException("T: row " + i + " has " + a[i].length
                        + " columns, expected " + cols);
            }
            for (int j = 0; j < cols; j++) {
                c[j][i] = a[i][j];
            }
        }
        return c;
    }

    /**
     * Matrix product {@code arr x arr2}.
     *
     * @param arr left matrix; every row must have {@code arr2.length} columns
     * @param arr2 right matrix; must be non-empty and rectangular
     * @return a new {@code arr.length x arr2[0].length} matrix
     * @throws IllegalArgumentException if the inner dimensions disagree, {@code arr2} is empty,
     *     or {@code arr2} is ragged
     */
    static double[][] dot(double[][] arr, double[][] arr2) {
        if (arr2.length == 0) {
            throw new IllegalArgumentException("dot: right-hand matrix has no rows");
        }
        int inner = arr2.length;
        int cols = arr2[0].length;
        for (int j = 0; j < inner; j++) {
            if (arr2[j].length != cols) {
                throw new IllegalArgumentException("dot: right-hand row " + j + " has "
                        + arr2[j].length + " columns, expected " + cols);
            }
        }

        double[][] c = new double[arr.length][cols];
        for (int x = 0; x < arr.length; x++) {
            if (arr[x].length != inner) {
                throw new IllegalArgumentException("dot: left-hand row " + x + " has "
                        + arr[x].length + " columns but right-hand matrix has " + inner + " rows");
            }
            // For each (x, i) the terms are still accumulated in ascending j order from 0.0.
            for (int j = 0; j < inner; j++) {
                double a = arr[x][j];
                for (int i = 0; i < cols; i++) {
                    c[x][i] += a * arr2[j][i];
                }
            }
        }
        return c;
    }

    /** Computes {@code sigmoid(input x weights)} element-wise as a new matrix. */
    private static double[][] forward(double[][] input) {
        double[][] result = dot(input, weights);
        for (double[] row : result) {
            for (int j = 0; j < row.length; j++) {
                row[j] = sigmoid(row[j]);
            }
        }
        return result;
    }
}
Author
duongddinh
commented
Nov 24, 2022
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment