Skip to content

Instantly share code, notes, and snippets.

@jasebell
Created April 6, 2026 19:25
Show Gist options
  • Select an option

  • Save jasebell/011f3f60ab5062d0394e1bf66628dceb to your computer and use it in GitHub Desktop.

Select an option

Save jasebell/011f3f60ab5062d0394e1bf66628dceb to your computer and use it in GitHub Desktop.
SimpleNeuralNetwork Demo - All math, no libraries.
import java.util.Arrays;
import java.util.Random;
/**
 * A minimal two-layer (input -> hidden -> output) feed-forward neural network implemented with
 * plain {@code double[][]} arithmetic and no external libraries.
 *
 * <p>Both layers use the sigmoid activation. Matrices are row-major with one sample per row, so an
 * input batch has shape {@code m x inputSize} and the output has shape {@code m x outputSize}.
 * Training performs one gradient-descent step per call to {@link #trainStep}, minimizing the mean
 * over samples of {@code 0.5 * sum((prediction - target)^2)}.
 *
 * <p>{@link #forwardPass} caches intermediate values that {@link #backwardPass} consumes, so a
 * backward pass must follow a forward pass on the same batch (as {@link #trainStep} does). Instances
 * hold mutable state without synchronization and must not be shared across threads.
 */
public class SimpleNeuralNetwork {
    /** Seed used by the three-argument constructor so the demo output is reproducible. */
    private static final long DEFAULT_SEED = 42L;
    /** Standard deviation of the Gaussian used to initialize the weight matrices. */
    private static final double INIT_WEIGHT_SCALE = 0.5;
    /** Number of training steps used by {@link #main} when no argument is supplied. */
    private static final int DEFAULT_ITERATIONS = 10000;
    /** Learning rate used throughout the {@link #main} demonstration. */
    private static final double DEMO_LEARNING_RATE = 0.5;
    /** {@link #main} prints the loss every this many steps (plus the first ten steps). */
    private static final int PROGRESS_INTERVAL = 10000;

    // Network parameters (names mirror the math notation used in the printed walkthrough)
    private double[][] W1; // Input to hidden weights, inputSize x hiddenSize
    private double[][] b1; // Hidden layer bias, 1 x hiddenSize
    private double[][] W2; // Hidden to output weights, hiddenSize x outputSize
    private double[][] b2; // Output layer bias, 1 x outputSize

    // Values cached by the most recent forwardPass and consumed by backwardPass
    private double[][] z1, a1, z2, a2;
    private double[][] X_cache;

    private final int inputSize;
    private final int hiddenSize;
    private final int outputSize;

    /**
     * Creates a network with weights drawn from a Gaussian seeded with 42, and zero biases, then
     * prints the initial parameters.
     *
     * @param inputSize number of input features per sample; must be positive
     * @param hiddenSize number of hidden units; must be positive
     * @param outputSize number of outputs per sample; must be positive
     * @throws IllegalArgumentException if any size is not positive
     */
    public SimpleNeuralNetwork(int inputSize, int hiddenSize, int outputSize) {
        this(inputSize, hiddenSize, outputSize, DEFAULT_SEED);
    }

    /**
     * Creates a network whose weights are drawn from {@code N(0, 0.5^2)} using the given seed, with
     * zero biases, then prints the initial parameters.
     *
     * @param inputSize number of input features per sample; must be positive
     * @param hiddenSize number of hidden units; must be positive
     * @param outputSize number of outputs per sample; must be positive
     * @param seed seed for the weight-initialization random generator
     * @throws IllegalArgumentException if any size is not positive
     */
    public SimpleNeuralNetwork(int inputSize, int hiddenSize, int outputSize, long seed) {
        if (inputSize <= 0 || hiddenSize <= 0 || outputSize <= 0) {
            throw new IllegalArgumentException(String.format(
                    "Layer sizes must be positive (input=%d, hidden=%d, output=%d)",
                    inputSize, hiddenSize, outputSize));
        }
        this.inputSize = inputSize;
        this.hiddenSize = hiddenSize;
        this.outputSize = outputSize;
        System.out.println("=== INITIALIZATION ===");
        Random rand = new Random(seed);
        // W1 is drawn before W2 so a given seed always yields the same parameters.
        W1 = randomMatrix(inputSize, hiddenSize, rand);
        System.out.println("Input to Hidden Weights (W1):");
        printMatrix(W1);
        b1 = new double[1][hiddenSize];
        System.out.println("Hidden Layer Bias (b1):");
        printMatrix(b1);
        W2 = randomMatrix(hiddenSize, outputSize, rand);
        System.out.println("Hidden to Output Weights (W2):");
        printMatrix(W2);
        b2 = new double[1][outputSize];
        System.out.println("Output Layer Bias (b2):");
        printMatrix(b2);
        System.out.println();
    }

    /** Returns a rows x cols matrix of Gaussian samples scaled by INIT_WEIGHT_SCALE, filled row by row. */
    private static double[][] randomMatrix(int rows, int cols, Random rand) {
        double[][] result = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result[i][j] = rand.nextGaussian() * INIT_WEIGHT_SCALE;
            }
        }
        return result;
    }

    /** Throws IllegalArgumentException unless {@code matrix} has exactly {@code rows} rows of {@code cols} columns. */
    private static void requireShape(double[][] matrix, int rows, int cols, String name) {
        if (matrix.length != rows) {
            throw new IllegalArgumentException(
                    name + " must have " + rows + " rows, got " + matrix.length);
        }
        for (int i = 0; i < rows; i++) {
            if (matrix[i].length != cols) {
                throw new IllegalArgumentException(
                        name + " row " + i + " must have " + cols + " columns, got " + matrix[i].length);
            }
        }
    }

    /** Sigmoid activation; the input is clipped to [-500, 500] so Math.exp cannot overflow. */
    private double sigmoid(double z) {
        z = Math.max(-500, Math.min(500, z));
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Derivative of the sigmoid evaluated at the pre-activation value {@code z}. */
    private double sigmoidDerivative(double z) {
        double s = sigmoid(z);
        return s * (1 - s);
    }

    /** Applies the sigmoid element-wise, returning a new matrix. */
    private double[][] applySigmoid(double[][] matrix) {
        double[][] result = new double[matrix.length][matrix[0].length];
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[0].length; j++) {
                result[i][j] = sigmoid(matrix[i][j]);
            }
        }
        return result;
    }

    /** Applies the sigmoid derivative element-wise to pre-activation values, returning a new matrix. */
    private double[][] applySigmoidDerivative(double[][] matrix) {
        double[][] result = new double[matrix.length][matrix[0].length];
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[0].length; j++) {
                result[i][j] = sigmoidDerivative(matrix[i][j]);
            }
        }
        return result;
    }

    /**
     * Standard matrix product {@code A * B}.
     *
     * @throws IllegalArgumentException if the column count of A differs from the row count of B
     */
    private double[][] matrixMultiply(double[][] A, double[][] B) {
        int rowsA = A.length;
        int colsA = A[0].length;
        if (B.length != colsA) {
            throw new IllegalArgumentException("Cannot multiply " + rowsA + "x" + colsA
                    + " matrix by matrix with " + B.length + " rows");
        }
        int colsB = B[0].length;
        double[][] result = new double[rowsA][colsB];
        for (int i = 0; i < rowsA; i++) {
            for (int j = 0; j < colsB; j++) {
                for (int k = 0; k < colsA; k++) {
                    result[i][j] += A[i][k] * B[k][j];
                }
            }
        }
        return result;
    }

    /** Adds the single row {@code B[0]} to every row of A (bias broadcasting). */
    private double[][] matrixAdd(double[][] A, double[][] B) {
        double[][] result = new double[A.length][A[0].length];
        for (int i = 0; i < A.length; i++) {
            for (int j = 0; j < A[0].length; j++) {
                result[i][j] = A[i][j] + B[0][j];
            }
        }
        return result;
    }

    /** Returns the transpose of {@code matrix}. */
    private double[][] transpose(double[][] matrix) {
        int rows = matrix.length;
        int cols = matrix[0].length;
        double[][] result = new double[cols][rows];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result[j][i] = matrix[i][j];
            }
        }
        return result;
    }

    /** Element-wise (Hadamard) product of two equally shaped matrices. */
    private double[][] elementWiseMultiply(double[][] A, double[][] B) {
        double[][] result = new double[A.length][A[0].length];
        for (int i = 0; i < A.length; i++) {
            for (int j = 0; j < A[0].length; j++) {
                result[i][j] = A[i][j] * B[i][j];
            }
        }
        return result;
    }

    /** Element-wise difference {@code A - B} of two equally shaped matrices. */
    private double[][] matrixSubtract(double[][] A, double[][] B) {
        double[][] result = new double[A.length][A[0].length];
        for (int i = 0; i < A.length; i++) {
            for (int j = 0; j < A[0].length; j++) {
                result[i][j] = A[i][j] - B[i][j];
            }
        }
        return result;
    }

    /**
     * Runs the network on a batch and caches the intermediate values needed by
     * {@link #backwardPass}.
     *
     * @param X input batch, one sample per row, each row of length {@code inputSize}
     * @param showSteps when true, prints every intermediate matrix
     * @return the output activations, shape {@code X.length x outputSize}
     * @throws IllegalArgumentException if X is empty or a row has the wrong width
     */
    public double[][] forwardPass(double[][] X, boolean showSteps) {
        if (X.length == 0) {
            throw new IllegalArgumentException("X must contain at least one sample");
        }
        // Without this check a narrow row would silently use only part of W1.
        requireShape(X, X.length, inputSize, "X");
        X_cache = X;
        if (showSteps) {
            System.out.println("=== FORWARD PASS ===");
            System.out.print("Input (X): ");
            printMatrix(X);
        }
        // Step 1: Input to Hidden Layer - z1 = X * W1 + b1
        z1 = matrixAdd(matrixMultiply(X, W1), b1);
        if (showSteps) {
            System.out.println("\nHidden Layer Linear Combination (z1 = X·W1 + b1):");
            printMatrix(z1);
        }
        // Step 2: Apply activation - a1 = sigmoid(z1)
        a1 = applySigmoid(z1);
        if (showSteps) {
            System.out.println("\nHidden Layer Activation (a1 = sigmoid(z1)):");
            printMatrix(a1);
        }
        // Step 3: Hidden to Output Layer - z2 = a1 * W2 + b2
        z2 = matrixAdd(matrixMultiply(a1, W2), b2);
        if (showSteps) {
            System.out.println("\nOutput Layer Linear Combination (z2 = a1·W2 + b2):");
            printMatrix(z2);
        }
        // Step 4: Apply activation - a2 = sigmoid(z2)
        a2 = applySigmoid(z2);
        if (showSteps) {
            System.out.println("\nFinal Output (a2 = sigmoid(z2)):");
            printMatrix(a2);
            System.out.println();
        }
        return a2;
    }

    /**
     * Computes the loss {@code (1/m) * sum over samples of 0.5 * sum((yPred - yTrue)^2)}.
     *
     * <p>The value is averaged over the {@code m} rows, which matches the {@code 1/m} scaling of the
     * gradients in {@link #backwardPass}. For a single sample it equals {@code 0.5 * sum((yPred - yTrue)^2)}.
     *
     * @param yPred predictions, one sample per row
     * @param yTrue targets with the same shape as {@code yPred}
     * @param showSteps when true, prints the formula and the loss
     * @return the mean loss, or 0.0 for an empty batch
     * @throws IllegalArgumentException if the shapes of yPred and yTrue differ
     */
    public double computeLoss(double[][] yPred, double[][] yTrue, boolean showSteps) {
        int m = yPred.length;
        if (yTrue.length != m) {
            throw new IllegalArgumentException(
                    "yTrue has " + yTrue.length + " rows but yPred has " + m);
        }
        double sum = 0.0;
        for (int i = 0; i < m; i++) {
            if (yTrue[i].length != yPred[i].length) {
                throw new IllegalArgumentException("Row " + i + ": yTrue has " + yTrue[i].length
                        + " columns but yPred has " + yPred[i].length);
            }
            for (int j = 0; j < yPred[i].length; j++) {
                double diff = yPred[i][j] - yTrue[i][j];
                sum += 0.5 * diff * diff;
            }
        }
        double loss = m == 0 ? 0.0 : sum / m;
        if (showSteps) {
            System.out.println("=== LOSS CALCULATION ===");
            System.out.println("Mean Squared Error: 0.5 * (y_pred - y_true)², averaged over samples");
            System.out.printf("Loss = %.6f%n%n", loss);
        }
        return loss;
    }

    /**
     * Backpropagates the error of the most recent {@link #forwardPass} and applies one
     * gradient-descent update to all weights and biases.
     *
     * @param X the same input batch that was passed to the preceding forwardPass
     * @param yTrue targets, shape {@code X.length x outputSize}
     * @param learningRate step size for the update
     * @param showSteps when true, prints the intermediate gradients and W2 before/after the update
     * @throws IllegalStateException if forwardPass has not been called yet
     * @throws IllegalArgumentException if X or yTrue do not match the cached forward pass
     */
    public void backwardPass(double[][] X, double[][] yTrue, double learningRate, boolean showSteps) {
        if (a2 == null) {
            throw new IllegalStateException("forwardPass must be called before backwardPass");
        }
        int m = X.length;
        // The cached activations belong to the last forward batch, so shapes must agree with it.
        requireShape(X, a2.length, inputSize, "X");
        requireShape(yTrue, a2.length, outputSize, "yTrue");
        if (showSteps) {
            System.out.println("=== BACKWARD PASS (BACKPROPAGATION) ===");
        }
        // Step 1: Compute output layer error
        double[][] outputError = matrixSubtract(a2, yTrue);
        if (showSteps) {
            System.out.print("Output Error (dL/da2): ");
            printMatrix(outputError);
        }
        // Step 2: Compute gradients for output layer
        double[][] dz2 = elementWiseMultiply(outputError, applySigmoidDerivative(z2));
        if (showSteps) {
            System.out.print("Output Layer Gradient (dL/dz2): ");
            printMatrix(dz2);
        }
        double[][] dW2 = scalarMultiply(matrixMultiply(transpose(a1), dz2), 1.0 / m);
        double[][] db2 = scalarMultiply(sumRows(dz2), 1.0 / m);
        // Step 3: Compute hidden layer error; must use W2 before it is updated below
        double[][] hiddenError = matrixMultiply(dz2, transpose(W2));
        if (showSteps) {
            System.out.print("\nHidden Layer Error: ");
            printMatrix(hiddenError);
        }
        // Step 4: Compute gradients for hidden layer
        double[][] dz1 = elementWiseMultiply(hiddenError, applySigmoidDerivative(z1));
        if (showSteps) {
            System.out.print("Hidden Layer Gradient (dL/dz1): ");
            printMatrix(dz1);
        }
        double[][] dW1 = scalarMultiply(matrixMultiply(transpose(X), dz1), 1.0 / m);
        double[][] db1 = scalarMultiply(sumRows(dz1), 1.0 / m);
        // Step 5: Update weights and biases (GRADIENT DESCENT)
        if (showSteps) {
            // %s prints the exact value; %.1f would misreport rates such as 0.05.
            System.out.printf("%n=== PARAMETER UPDATES (Learning Rate: %s) ===%n", learningRate);
            System.out.print("Old W2: ");
            printMatrix(W2);
        }
        W2 = matrixSubtract(W2, scalarMultiply(dW2, learningRate));
        b2 = matrixSubtract(b2, scalarMultiply(db2, learningRate));
        W1 = matrixSubtract(W1, scalarMultiply(dW1, learningRate));
        b1 = matrixSubtract(b1, scalarMultiply(db1, learningRate));
        if (showSteps) {
            System.out.print("New W2: ");
            printMatrix(W2);
            System.out.println();
        }
    }

    /**
     * Performs one training step: forward pass, loss, then backward pass with a parameter update.
     *
     * @return the loss measured before the update
     */
    public double trainStep(double[][] X, double[][] y, double learningRate, boolean showSteps) {
        double[][] yPred = forwardPass(X, showSteps);
        double loss = computeLoss(yPred, y, showSteps);
        backwardPass(X, y, learningRate, showSteps);
        return loss;
    }

    /** Sums the rows of {@code matrix} into a single 1 x cols row. */
    private double[][] sumRows(double[][] matrix) {
        double[][] result = new double[1][matrix[0].length];
        for (int j = 0; j < matrix[0].length; j++) {
            for (int i = 0; i < matrix.length; i++) {
                result[0][j] += matrix[i][j];
            }
        }
        return result;
    }

    /** Returns a new matrix with every element multiplied by {@code scalar}. */
    private double[][] scalarMultiply(double[][] matrix, double scalar) {
        double[][] result = new double[matrix.length][matrix[0].length];
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[0].length; j++) {
                result[i][j] = matrix[i][j] * scalar;
            }
        }
        return result;
    }

    /** Prints {@code matrix} to stdout, one row per line, with four decimal places. */
    private void printMatrix(double[][] matrix) {
        System.out.print("[");
        for (int i = 0; i < matrix.length; i++) {
            if (i > 0) {
                System.out.print(" ");
            }
            System.out.print("[");
            for (int j = 0; j < matrix[0].length; j++) {
                System.out.printf("%.4f", matrix[i][j]);
                if (j < matrix[0].length - 1) {
                    System.out.print(", ");
                }
            }
            System.out.print("]");
            if (i < matrix.length - 1) {
                System.out.println();
            }
        }
        System.out.println("]");
    }

    /**
     * Trains the network on XOR one sample at a time and prints a detailed walkthrough of the first
     * step, periodic loss values, and the final predictions.
     *
     * @param args optional {@code args[0]}: total number of training steps (default 10000)
     * @throws NumberFormatException if {@code args[0]} is not an integer
     */
    public static void main(String[] args) {
        // Without a guard, running with no argument crashed on args[0].
        int iterations = args.length > 0 ? Integer.parseInt(args[0].trim()) : DEFAULT_ITERATIONS;
        System.out.println("NEURAL NETWORK STEP-BY-STEP MATH DEMONSTRATION");
        System.out.println("============================================================");
        // Create XOR dataset
        double[][] X = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        double[][] y = {{0}, {1}, {1}, {0}};
        System.out.println("Training Data (XOR Problem):");
        System.out.println("Inputs (X):");
        for (double[] row : X) {
            System.out.println(Arrays.toString(row));
        }
        System.out.println("Expected Outputs (y):");
        for (double[] row : y) {
            System.out.println(row[0]);
        }
        System.out.println("\n============================================================");
        SimpleNeuralNetwork nn = new SimpleNeuralNetwork(2, 3, 1);
        // Show one complete training step in detail
        System.out.println("TRAINING STEP 1 - DETAILED WALKTHROUGH");
        System.out.println("============================================================");
        double[][] X_single = {{X[0][0], X[0][1]}};
        double[][] y_single = {{y[0][0]}};
        double loss = nn.trainStep(X_single, y_single, DEMO_LEARNING_RATE, true);
        System.out.println("============================================================");
        System.out.printf("LOSS AFTER STEP 1: %.6f%n", loss);
        System.out.println("============================================================");
        // Step 1 always runs above; report the real number of remaining steps.
        System.out.printf("%nTRAINING PROGRESS (Next %d steps):%n", Math.max(0, iterations - 1));
        System.out.println("----------------------------------------");
        for (int step = 2; step <= iterations; step++) {
            int idx = (step - 1) % X.length; // cycle through the samples in order
            X_single = new double[][]{{X[idx][0], X[idx][1]}};
            y_single = new double[][]{{y[idx][0]}};
            loss = nn.trainStep(X_single, y_single, DEMO_LEARNING_RATE, false);
            if (step % PROGRESS_INTERVAL == 0 || step <= 10) {
                System.out.printf("Step %d: Loss = %.6f%n", step, loss);
            }
        }
        System.out.println("\n============================================================");
        System.out.println("FINAL NETWORK TESTING");
        System.out.println("============================================================");
        for (int i = 0; i < X.length; i++) {
            X_single = new double[][]{{X[i][0], X[i][1]}};
            double[][] prediction = nn.forwardPass(X_single, false);
            System.out.printf("Input: %s → Prediction: %.4f, Expected: %.0f%n",
                    Arrays.toString(X[i]), prediction[0][0], y[i][0]);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment