Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Functions required for gradient descent to fit a Logistic Regression model.
/**
 * Searches for a weight vector by gradient descent, starting from {@code w0}.
 *
 * <p>Each iteration asks {@code stepSize(...)} (defined elsewhere; described by the
 * original author as a Wolfe-condition line search) for a step length along the
 * negative gradient, applies that step, and re-evaluates the gradient.
 *
 * <p>The loop ends when the Euclidean norm of the gradient is no longer above
 * {@code tolerance} or when {@code maxiter} iterations have been performed.
 * If no iteration runs, the returned reference is {@code w0} itself.
 *
 * @param w0        starting weights
 * @param tolerance gradient-norm threshold at which the search stops
 * @param maxiter   upper bound on the number of iterations
 * @return the weights held when the loop ended
 */
public double[] gradientDescent(double w0[], double tolerance, int maxiter) {
    double[] weights = w0;
    double[] grad = gradient(w0);
    // Norm test is evaluated before the iteration-count test, as in the original.
    for (int step = 0; Math.sqrt(dot(grad, grad)) > tolerance && step < maxiter; step++) {
        // step length along the steepest-descent direction (-grad)
        final double alpha = stepSize(this, weights, scale(grad, -1), 1, 500, 0.1, 0.9);
        // move to the new point, then refresh the gradient there
        weights = add(weights, scale(grad, -alpha));
        grad = gradient(weights);
    }
    return weights;
}
//returns the error of a logistic regressions with weights w on the given input and output
//output should be in the form 0 for negative, 1 for positive
public double error(double w[]){
feval++;//keep track of how many times this has been called
double error = 0 ;
for(int k=0;k<input.length;k++){
double diff = s( dot(w,input[k]) ) - output[k] ;
error += diff*diff ;
}
return error ;
}
/**
 * Returns the gradient of {@code error(w)} with respect to the weights.
 *
 * <p>For each sample, the chain rule gives the per-sample factor
 * {@code 2 * (s(z) - output[k]) * ds(z)} with {@code z = dot(w, input[k])};
 * that factor times the sample's j-th input is added to component j.
 * Outputs are expected to be 0 for negative and 1 for positive examples.
 * Each call also increments the {@code geval} call counter.
 *
 * @param w weight vector at which the gradient is evaluated
 * @return a new array of length {@code w.length} holding the gradient
 */
public double[] gradient(double[] w) {
    geval++; // count gradient evaluations
    final int dim = w.length;
    final double[] grad = new double[dim];
    for (int sample = 0; sample < input.length; sample++) {
        final double z = dot(w, input[sample]);
        // d(error_k)/dz for this sample
        final double coef = 2 * (s(z) - output[sample]) * ds(z);
        for (int j = 0; j < dim; j++) {
            grad[j] += input[sample][j] * coef;
        }
    }
    return grad;
}
/**
 * Returns a central-difference approximation of the gradient of {@code error}:
 * component j is {@code (error(w + eps*e_j) - error(w - eps*e_j)) / (2*eps)}.
 *
 * <p>Intended for checking {@code gradient(w)} in tests, not for optimisation.
 *
 * <p>The perturbation is applied to a private copy of {@code w}. The earlier
 * version modified the caller's array in place and undid the change with
 * {@code += eps, -= 2*eps, += eps}, which is not exact in floating point and
 * could leave the caller's weights slightly altered.
 *
 * @param w       point at which to approximate the gradient (left untouched)
 * @param epsilon half-width of the finite-difference stencil
 * @return a new array of length {@code w.length} with the approximate gradient
 */
public double[] numericalGradient(double[] w, double epsilon) {
    final double[] probe = w.clone(); // scratch copy; caller's array is never written
    final double[] g = new double[w.length];
    for (int j = 0; j < g.length; j++) {
        probe[j] = w[j] + epsilon;
        final double forward = error(probe);
        probe[j] = w[j] - epsilon;
        final double backward = error(probe);
        probe[j] = w[j]; // exact reset before moving to the next coordinate
        g[j] = (forward - backward) / (2 * epsilon);
    }
    return g;
}
/**
 * Sigmoid / logistic function {@code 1 / (1 + e^-x)}.
 *
 * <p>Evaluated in a numerically stable way: only the exponential of a
 * non-positive number is ever taken, so it cannot overflow. The naive form
 * {@code e^x / (e^x + 1)} yields {@code Infinity / Infinity = NaN} once
 * {@code e^x} overflows (x above roughly 709.78).
 *
 * @param x input value
 * @return the logistic function of {@code x}, in the range [0, 1]
 *         (NaN only if {@code x} is NaN)
 */
public static double s(double x) {
    if (x >= 0) {
        // e^-x is in (0, 1]; no overflow possible
        return 1.0 / (1.0 + Math.exp(-x));
    }
    // x < 0 (or NaN): e^x is in [0, 1); same formula as before on the safe side
    final double ex = Math.exp(x);
    return ex / (ex + 1.0);
}
/**
 * Derivative of the sigmoid / logistic function: {@code e^x / (1 + e^x)^2}.
 *
 * <p>The derivative is an even function, so it is evaluated as
 * {@code e^-|x| / (1 + e^-|x|)^2}. Only the exponential of a non-positive
 * number is taken, which avoids the {@code Infinity / Infinity = NaN} result
 * the direct formula gives for large positive {@code x}, and the early
 * underflow to 0 once {@code (e^x + 1)^2} overflows.
 *
 * @param x input value
 * @return the derivative of the logistic function at {@code x}, in [0, 0.25]
 *         (NaN only if {@code x} is NaN)
 */
public static double ds(double x) {
    final double e = Math.exp(-Math.abs(x)); // in [0, 1]
    final double denom = 1.0 + e;
    return e / (denom * denom);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.