Last active
December 25, 2015 17:49
-
-
Save Alrecenk/7016330 to your computer and use it in GitHub Desktop.
An optimized rotation forest algorithm for binary classification on fixed length feature vectors.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*A rotation forest algorithm for binary classification with fixed-length feature vectors.
 *Created by Alrecenk for inductivebias.com, Oct 2013.
 */
import java.util.ArrayList; | |
import java.util.Arrays; | |
import java.util.Random; | |
public class RotationForestSimple{
    double mean[]; //the mean of each input axis, used to normalize data
    double deviation[]; //the standard deviation of each input axis, used to normalize data
    Treenode tree[]; //the trees in the forest

    /*Builds a rotation forest from the given input/output pairs.
     *output should be <=0 for negative cases and >0 for positive cases.
     *seed seeds the random number generator so training is reproducible.
     *trees is the number of trees in the forest.
     *minpoints is the minimum number of data points required for a split to be considered at a leaf.
     *maxdepth is the maximum allowed depth of any leaf.
     *The caller's input and output arrays are not modified.
     */
    public RotationForestSimple(double input[][], double output[], int seed, int trees, int minpoints, int maxdepth){
        if (input.length == 0 || input[0].length == 0){
            throw new IllegalArgumentException("Training data must contain at least one point with at least one feature.");
        }
        //calculate mean and standard deviation along each input variable
        int dim = input[0].length;
        mean = new double[dim];
        deviation = new double[dim];
        double sum[] = new double[dim];
        double sum2[] = new double[dim];
        for (int k = 0; k < input.length; k++){
            for (int j = 0; j < dim; j++){
                double val = input[k][j];
                sum[j] += val;
                sum2[j] += val * val;
            }
        }
        for (int j = 0; j < dim; j++){
            mean[j] = sum[j] / input.length;
            //variance = E[x^2] - E[x]^2; clamp at zero in case round-off makes it slightly negative
            double variance = Math.max(0, sum2[j] / input.length - mean[j] * mean[j]);
            deviation[j] = Math.sqrt(variance);
            if (deviation[j] == 0){
                deviation[j] = 1;//constant feature: avoid division by zero (axis is all zeros after centering)
            }
        }
        //normalize a private copy of the data so each axis has mean of zero and deviation of one
        //(working on copies so the caller's arrays are left untouched)
        double normalized[][] = new double[input.length][];
        for (int k = 0; k < input.length; k++){
            normalized[k] = copy(input[k]);
            normalize(normalized[k], mean, deviation);
        }
        //initialize the forest
        tree = new Treenode[trees];
        int datapermodel = input.length;
        Random rand = new Random(seed);//seeded random number generator for reproducible training
        for (int k = 0; k < trees; k++){
            ArrayList<Datapoint> treedata = new ArrayList<Datapoint>(datapermodel);
            //bootstrap aggregating of training data
            for (int j = 0; j < datapermodel; j++){
                //add a random data point (with replacement) to the training data for this tree;
                //nextInt(bound) is used because Math.abs(rand.nextInt())%n overflows when
                //nextInt() returns Integer.MIN_VALUE, producing a negative index
                int nj = rand.nextInt(input.length);
                treedata.add(new Datapoint(copy(normalized[nj]), output[nj] > 0));
            }
            //create this tree
            tree[k] = new Treenode(treedata, minpoints, rand, maxdepth);
        }
    }

    //applies the forest to a new input and returns the approximate probability of the positive class
    public double apply(double[] input){
        input = copy(input);//copy so as not to alter the caller's array
        normalize(input, mean, deviation);
        double output = 0;
        for (int k = 0; k < tree.length; k++){
            output += tree[k].apply(input);
        }
        return output / tree.length;//average of the per-tree leaf probabilities
    }

    //returns a vector of dim independent standard-normal samples using the Box-Muller transform
    //(normalizing such a vector yields a direction uniformly distributed on the unit N-sphere)
    public static double[] normaldistribution(int dim, Random rand){
        double[] axis = new double[dim];
        for (int k = 0; k < dim; k++){
            //1 - nextDouble() lies in (0, 1], so Math.log never receives zero (nextDouble() may return 0.0)
            double u = 1.0 - rand.nextDouble();
            axis[k] = Math.sqrt(-2 * Math.log(u)) * Math.sin(2 * Math.PI * rand.nextDouble());
        }
        return axis;
    }

    //scales a vector in place to length one; the zero vector is left unchanged to avoid division by zero
    public static void normalize(double a[]){
        double lengthsquared = 0;
        for (int k = 0; k < a.length; k++){
            lengthsquared += a[k] * a[k];
        }
        if (lengthsquared == 0){
            return;//cannot normalize the zero vector
        }
        double scale = 1 / Math.sqrt(lengthsquared);
        for (int k = 0; k < a.length; k++){
            a[k] *= scale;
        }
    }

    //scales a point in place with the given mean and deviation so it has mean of zero and deviation of one
    public static void normalize(double a[], double mean[], double deviation[]){
        for (int k = 0; k < a.length; k++){
            a[k] = (a[k] - mean[k]) / deviation[k];
        }
    }

    //dot product of two vectors (assumed to be the same length)
    public static double dot(double[] a, double[] b){
        double c = 0;
        for (int k = 0; k < a.length; k++){
            c += a[k] * b[k];
        }
        return c;
    }

    //copies a vector to a new array
    public static double[] copy(double[] b){
        return Arrays.copyOf(b, b.length);
    }

    //returns a new vector equal to a*s
    public static double[] scale(double[] a, double s){
        double[] b = new double[a.length];
        for (int k = 0; k < a.length; k++){
            b[k] = a[k] * s;
        }
        return b;
    }

    //returns a new vector equal to the difference a-b
    public static double[] subtract(double[] a, double[] b){
        double[] c = new double[a.length];
        for (int k = 0; k < a.length; k++){
            c[k] = a[k] - b[k];
        }
        return c;
    }

    //A single node (branch or leaf) of a rotation tree.
    //static nested: nodes never touch the enclosing forest instance, so no hidden outer reference is kept.
    private static class Treenode{
        //This data makes up the final tree nodes
        boolean branchnode = false;//whether this node is a branch
        double[] splitaxis;//the axis this node split on, if a branch
        double splitvalue;//split plane is X dot splitaxis > splitvalue
        Treenode upper, lower;//child nodes
        public int totalpositive, totalnegative;//amount of each class at this node
        //This data is used only during the training process
        double[][] axis;//each candidate axis over which splitting is considered
        Datapoint[][] data;//the data at this node, sorted separately along each axis

        //This constructor should only be called for the root node of a new tree, as it performs all of the sorting.
        public Treenode(ArrayList<Datapoint> traindata, int minpoints, Random rand, int maxdepth){
            //make the random projection axes
            int axes = traindata.get(0).input.length;
            axis = new double[axes][];
            for (int k = 0; k < axes; k++){
                //a normalized normally distributed random vector is uniformly distributed on an N-sphere
                axis[k] = normaldistribution(axes, rand);
                //subtract out the projection onto the previous few axes to make them closer to orthogonal
                //(not strictly necessary, but it seems to help a little)
                for (int j = Math.max(0, k - 5); j < k; j++){
                    axis[k] = subtract(axis[k], scale(axis[j], dot(axis[k], axis[j])));
                }
                normalize(axis[k]);
            }
            //make one array per axis holding all data points sorted along that axis
            data = new Datapoint[axes][];
            for (int k = 0; k < axes; k++){
                data[k] = new Datapoint[traindata.size()];
                for (int j = 0; j < traindata.size(); j++){
                    Datapoint d = traindata.get(j);
                    data[k][j] = d;
                    d.setsortvalue(d.dot(axis[k]));//sort on the projection onto this axis
                }
                Arrays.sort(data[k]);//sort this array based on this axis
            }
            //count up the total positive and total negative instances in the data set
            totalpositive = 0;
            totalnegative = 0;
            for (int j = 0; j < traindata.size(); j++){
                if (traindata.get(j).output){
                    totalpositive++;
                }else{
                    totalnegative++;
                }
            }
            //recursively split the tree until done
            recursivesplit(minpoints, maxdepth);
        }

        //This constructor is called internally during the splitting process.
        //data holds the same Datapoint objects once per axis, kept sorted on each axis.
        public Treenode(Datapoint[][] data, double[][] axis, int tp, int tn){
            this.data = data;
            this.axis = axis;
            branchnode = false;
            totalpositive = tp;
            totalnegative = tn;
        }

        //Recursively splits this subtree until maxdepth is reached or leaves contain fewer than minpoints or hold a single class.
        public void recursivesplit(int minpoints, int maxdepth){
            if (maxdepth > 1 && split(minpoints)){//attempt split; on success recurse into the children
                lower.recursivesplit(minpoints, maxdepth - 1);
                upper.recursivesplit(minpoints, maxdepth - 1);
            }
            //a node never attempts to split more than once, so the training data can be released
            data = null;
        }

        //splits this node if it should and returns whether it did
        public boolean split(int minpoints){
            //if already split, or one class, or not enough points remaining, then don't split
            if (branchnode || totalpositive == 0 || totalnegative == 0 || totalpositive + totalnegative < minpoints){
                return false;
            }
            int bestaxis = -1, splitafter = -1;
            double bestscore = Double.MAX_VALUE;//any valid split will beat no split
            int bestLp = 0, bestLn = 0;
            for (int k = 0; k < data.length; k++){//try each axis
                int Lp = 0, Ln = 0, Rp = totalpositive, Rn = totalnegative;//reset the left/right class counts
                for (int j = 0; j < data[k].length - 1; j++){//walk through the sorted data points
                    if (data[k][j].output){
                        Lp++;//update positive counts
                        Rp--;
                    }else{
                        Ln++;//update negative counts
                        Rn--;
                    }
                    //score by a parabola approximating information gain; lower is better.
                    //multiply in double to avoid int overflow of Lp*Ln on very large nodes
                    double score = (double)Lp * Ln / (Lp + Ln) + (double)Rp * Rn / (Rp + Rn);
                    if (score < bestscore){
                        bestscore = score;//save score
                        bestaxis = k;//save axis
                        splitafter = j;//save split location
                        bestLp = Lp;//save positives and negatives to the left of the split
                        bestLn = Ln;//so they don't need to be counted again later
                    }
                }
            }
            if (bestscore >= Double.MAX_VALUE){//no valid split found
                return false;
            }
            splitaxis = axis[bestaxis];
            //split halfway between the two points around the split
            splitvalue = 0.5 * (data[bestaxis][splitafter].dot(splitaxis) + data[bestaxis][splitafter + 1].dot(splitaxis));
            Datapoint[][] lowerdata = new Datapoint[axis.length][];
            Datapoint[][] upperdata = new Datapoint[axis.length][];
            //separate the split axis directly and mark which child each point belongs to
            lowerdata[bestaxis] = new Datapoint[splitafter + 1];
            upperdata[bestaxis] = new Datapoint[data[bestaxis].length - splitafter - 1];
            for (int k = 0; k <= splitafter; k++){
                data[bestaxis][k].setchild(false);
                lowerdata[bestaxis][k] = data[bestaxis][k];
            }
            for (int k = splitafter + 1; k < data[bestaxis].length; k++){
                data[bestaxis][k].setchild(true);
                upperdata[bestaxis][k - splitafter - 1] = data[bestaxis][k];
            }
            //separate all the other axes, maintaining their sort order
            for (int k = 0; k < axis.length; k++){
                if (k != bestaxis){//bestaxis was already separated above
                    lowerdata[k] = new Datapoint[splitafter + 1];
                    upperdata[k] = new Datapoint[data[bestaxis].length - splitafter - 1];
                    //fill the data into these arrays without changing their order
                    int lowerindex = 0, upperindex = 0;
                    for (int j = 0; j < data[k].length; j++){
                        if (data[k][j].upperchild){//goes in the upper node
                            upperdata[k][upperindex] = data[k][j];
                            upperindex++;
                        }else{//goes in the lower node
                            lowerdata[k][lowerindex] = data[k][j];
                            lowerindex++;
                        }
                    }
                }
            }
            //initialize but do not yet split the children
            lower = new Treenode(lowerdata, axis, bestLp, bestLn);
            upper = new Treenode(upperdata, axis, totalpositive - bestLp, totalnegative - bestLn);
            branchnode = true;
            return true;
        }

        //returns a value in the range 0 to 1 roughly approximating the chance of being positive
        public double apply(double[] x){
            if (!branchnode){
                return totalpositive / (double)(totalpositive + totalnegative);
            }else if (dot(x, splitaxis) < splitvalue){
                return lower.apply(x);
            }else{
                return upper.apply(x);
            }
        }
    }//ends nested Treenode class

    //Datapoint holds an input vector and its class label together; instances are shared
    //across the per-axis arrays and sorted by a per-axis projection value.
    private static class Datapoint implements Comparable<Datapoint>{
        public double[] input;
        public boolean output;//true for the positive class
        public boolean upperchild;//which child node this point goes to during a split
        public double sortvalue;//the projection value currently being sorted on

        //initializes a Datapoint
        public Datapoint(double[] input, boolean output){
            this.input = input;
            this.output = output;
        }

        //dot product of this data point's input with the given axis vector
        public double dot(double[] axis){
            double dot = 0;
            for (int k = 0; k < axis.length; k++){
                dot += input[k] * axis[k];
            }
            return dot;
        }

        //sets which child this data point should go to
        public void setchild(boolean upper){
            upperchild = upper;
        }

        //sets the value to be sorted on
        public void setsortvalue(double val){
            sortvalue = val;
        }

        //orders data points by sortvalue; typed Comparable avoids raw-type casts
        public int compareTo(Datapoint other){
            return Double.compare(sortvalue, other.sortvalue);
        }
    }//ends nested Datapoint class
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
That is correct. A "Random" object is created with the seed around line 48 in the forest initialization, and that's used for generating all the random numbers in the actual trees. Pseudo-randomness makes debugging much easier.