Skip to content

Instantly share code, notes, and snippets.

@Alrecenk
Created November 5, 2013 04:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Alrecenk/7314019 to your computer and use it in GitHub Desktop.
Save Alrecenk/7314019 to your computer and use it in GitHub Desktop.
The core learning algorithm for the rotation forest that calculates the best split based on approximate information gain.
/**
 * Splits this node into a lower and an upper child along the best-scoring axis,
 * if a split is warranted, and reports whether it did.
 *
 * <p>{@code data} is assumed to be a set of presorted lists where {@code data[k][j]}
 * is the jth element of the data when sorted by {@code axis[k]}. Every axis is scanned
 * once; each position between two consecutive points is scored, and the lowest score wins.
 * The children are created but not split themselves.
 *
 * @param minpoints minimum number of points (positive plus negative) this node must
 *                  hold for a split to be attempted
 * @return {@code true} if the node was split and {@code lower}/{@code upper} were created;
 *         {@code false} if it was already split, holds only one class, holds fewer than
 *         {@code minpoints} points, or no valid split was found
 */
public boolean split(int minpoints){
    //if already split or one class or not enough points remaining then don't split
    if (branchnode || totalpositive == 0 || totalnegative == 0 || totalpositive + totalnegative < minpoints){
        return false;
    }

    int bestaxis = -1, splitafter = -1;
    double bestscore = Double.MAX_VALUE;//any valid split will beat no split
    int bestLp = 0, bestLn = 0;
    for (int k = 0; k < data.length; k++){//try each axis
        int Lp = 0, Ln = 0, Rp = totalpositive, Rn = totalnegative;//reset the +/- counts
        //stop one short of the end so the right side of the split is never empty
        for (int j = 0; j < data[k].length - 1; j++){//walk through the data points
            if (data[k][j].output){
                Lp++;//update positive counts
                Rp--;
            }else{
                Ln++;//update negative counts
                Rn--;
            }
            //score by a parabola approximating information gain
            //the products are formed in double: as int products they overflow once a side
            //holds more than ~46k points of each class, which corrupts the score
            //NOTE(review): a candidate split may fall between two points whose projections
            //on this axis are equal; confirm whether the presorted data can contain ties
            double score = (double)Lp * Ln / (Lp + Ln) + (double)Rp * Rn / (Rp + Rn);
            if (score < bestscore){ // lower score is better
                bestscore = score;//save score
                bestaxis = k;//save axis
                splitafter = j;//save split location
                bestLp = Lp;//save positives and negatives to left of split
                bestLn = Ln;//so they don't need to be counted again later
            }
        }
    }

    //if no valid split was found
    if (!(bestscore < Double.MAX_VALUE)){
        return false;//return did not split
    }

    splitaxis = axis[bestaxis];
    //split halfway between the 2 points around the split
    splitvalue = 0.5 * (data[bestaxis][splitafter].dot(splitaxis) + data[bestaxis][splitafter + 1].dot(splitaxis));

    int lowercount = splitafter + 1;
    int uppercount = data[bestaxis].length - splitafter - 1;
    Datapoint[][] lowerdata = new Datapoint[axis.length][];
    Datapoint[][] upperdata = new Datapoint[axis.length][];
    lowerdata[bestaxis] = new Datapoint[lowercount];//initialize the child data arrays for the split axis
    upperdata[bestaxis] = new Datapoint[uppercount];
    for (int k = 0; k <= splitafter; k++){//for the lower node data points
        data[bestaxis][k].setchild(false);//mark which child it goes to
        lowerdata[bestaxis][k] = data[bestaxis][k];//go ahead and separate the split axis
    }
    for (int k = splitafter + 1; k < data[bestaxis].length; k++){//for the upper node data points
        data[bestaxis][k].setchild(true);//mark which child it goes to
        upperdata[bestaxis][k - splitafter - 1] = data[bestaxis][k];//go ahead and separate the split axis
    }

    //separate all the other axes maintaining sorting
    for (int k = 0; k < axis.length; k++){
        if (k == bestaxis){
            continue;//the split axis was already separated above
        }
        //initialize the arrays
        lowerdata[k] = new Datapoint[lowercount];
        upperdata[k] = new Datapoint[uppercount];
        //fill the data into these arrays without changing order
        int lowerindex = 0, upperindex = 0;
        for (int j = 0; j < data[k].length; j++){
            if (data[k][j].upperchild){//if goes in upper node
                upperdata[k][upperindex] = data[k][j];
                upperindex++;
            }else{//if goes in lower node
                lowerdata[k][lowerindex] = data[k][j];
                lowerindex++;
            }
        }
    }

    //initialize but do not yet split the children
    lower = new Treenode(lowerdata, axis, bestLp, bestLn);
    upper = new Treenode(upperdata, axis, totalpositive - bestLp, totalnegative - bestLn);
    branchnode = true;
    return true;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment