Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
The core learning algorithm for the rotation forest that calculates the best split based on approximate information gain.
//splits this node if it should and returns whether it did
//data is assumed to be a set of presorted lists where data[k][j] is the jth element of data when sorted by axis[k]
public boolean split(int minpoints){
//if already split or one class or not enough points remaining then don't split
if (branchnode || totalpositive == 0 || totalnegative == 0 || totalpositive + totalnegative < minpoints){
return false;
int bestaxis = -1, splitafter=-1;
double bestscore = Double.MAX_VALUE;//any valid split will beat no split
int bestLp=0, bestLn=0;
for (int k = 0; k < data.length; k++){//try each axis
int Lp = 0, Ln = 0, Rp = totalpositive, Rn = totalnegative;//reset the +/- counts
for (int j = 0; j < data[k].length - 1; j++){//walk through the data points
if (data[k][j].output){
Lp++;//update positive counts
Ln++;//update negative counts
//score by a parabola approximating information gain
double score = Lp * Ln / (double)(Lp + Ln) + Rp * Rn / (double)(Rp + Rn);
if (score < bestscore){ // lower score is better
bestscore = score;//save score
bestaxis = k;//save axis
splitafter = j;//svale split location
bestLp = Lp;//save positives and negatives to left of split
bestLn = Ln ;//so they don't need to be counted again later
//if we got a valid split
if (bestscore < Double.MAX_VALUE){
splitaxis = axis[bestaxis];
//split halfway between the 2 points around the split
splitvalue = 0.5 * (data[bestaxis][splitafter].dot(splitaxis) + data[bestaxis][splitafter + 1].dot(splitaxis));
Datapoint[][] lowerdata = new Datapoint[axis.length][] ;
Datapoint[][] upperdata = new Datapoint[axis.length][] ;
lowerdata[bestaxis] = new Datapoint[splitafter+1] ;//initialize the child data arrays for the split axis
upperdata[bestaxis] = new Datapoint[data[bestaxis].length-splitafter-1] ;
for (int k = 0; k <= splitafter; k++){//for the lower node data points
data[bestaxis][k].setchild(false);//mark which leaf it goes to
lowerdata[bestaxis][k] = data[bestaxis][k];//go ahead and separate the split axis
for (int k = splitafter+1; k < data[bestaxis].length; k++){
data[bestaxis][k].setchild(true);//mark which leaf it goes on
upperdata[bestaxis][k-splitafter-1] = data[bestaxis][k] ;//go ahead and separate the split axis
//separate all the other axes maintaining sorting
for (int k = 0; k < axis.length; k++){
if (k != bestaxis){//we already did bestaxis=k above
//initialize the arrays
lowerdata[k] = new Datapoint[splitafter + 1];
upperdata[k] = new Datapoint[data[bestaxis].length - splitafter - 1];
//fill the data into these arrays without changing order
int lowerindex=0,upperindex=0;
for (int j = 0; j < data[k].length; j++){
if (data[k][j].upperchild){//if goes in upper node
upperdata[k][upperindex] = data[k][j];
upperindex++;//put in upper node data array
}else{//if goes in lower node
lowerdata[k][lowerindex] = data[k][j];
lowerindex++;//put in lower node array
//initialize but do not yet split the children
lower = new Treenode(lowerdata, axis, bestLp, bestLn);
upper = new Treenode(upperdata, axis, totalpositive - bestLp, totalnegative - bestLn);
branchnode = true;
return true;
}else{//if no valid splits found
return false ;//return did not split
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.