Forked by @deanhu2 from ngopal/sgdlinreg.java on January 4, 2018.
Stochastic Gradient Descent, but this time in Java.
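For a simple linear model y = b0 + b1 * x, the per-sample update the code performs is: error = predicted - actual, then b0 = b0 - learning_rate * error and b1 = b1 - learning_rate * error * x, repeated over every training row for a fixed number of epochs.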
package com.nikhilgopal.spark;

/**
 * Created by nikhilgopal on 3/24/17.
 */
public class SGDLinReg {

    public static void main(String[] args) {
        // Toy dataset: each row is {x, y}.
        double[][] dataset = {
                {1.0, 1.0},
                {2.0, 3.0},
                {4.0, 3.0},
                {3.0, 2.0},
                {5.0, 5.0}
        };
        // Fit intercept and slope with learning rate 0.001 over 500 epochs.
        double[] newcoef = coef_sgd(dataset, 0.001, 500);
        System.out.println(newcoef[0] + " " + newcoef[1]);
    }
    // Estimate coefficients (intercept plus one weight per input column) with stochastic gradient descent.
    private static double[] coef_sgd(double[][] train, double l_rate, int n_epoch) {
        double[] coefficients = {0.0, 0.0};
        for (int e = 0; e < n_epoch; e++) {
            double sum_error = 0.0;
            for (int w = 0; w < train.length; w++) {
                double yhat = predict(train[w], coefficients);
                // The last column of each row holds the target value.
                double error = yhat - train[w][1];
                sum_error += error * error;
                // Update the intercept, then each input weight.
                coefficients[0] = coefficients[0] - l_rate * error;
                for (int k = 0; k < (train[w].length - 1); k++) {
                    coefficients[k + 1] = coefficients[k + 1] - l_rate * error * train[w][k];
                }
            }
            System.out.println("EPOCH " + e + " LRATE " + l_rate + " ERR " + sum_error);
        }
        return coefficients;
    }
    // Predict y for one row: intercept plus the weighted sum of the input columns.
    private static double predict(double[] row, double[] coef) {
        double yhat = coef[0];
        for (int i = 0; i < (coef.length - 1); i++) {
            yhat += coef[i + 1] * row[i];
        }
        return yhat;
    }
}
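A quick way to exercise the learned model is to score a fresh input with predict once training finishes. The lines below are an illustrative sketch, not part of the original gist: the input value 6.0 is an arbitrary example, and they would be appended to the end of main inside the class above, where newcoef is in scope.

        // Illustrative addition (hypothetical, not in the original gist): predict for a new x.
        // Only row[0] is read by predict, so a single-element row is enough here.
        double[] newRow = {6.0};
        System.out.println("prediction for x=6.0: " + predict(newRow, newcoef));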