{{ message }}

Instantly share code, notes, and snippets.

Alrecenk/RotationForestSimple.java

Last active Dec 25, 2015
An optimized rotation forest algorithm for binary classification on fixed length feature vectors.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
 /*A rotation forest algorithm for binary classification with fixed length feature vectors. *created by Alrecenk for inductivebias.com Oct 2013 */ import java.util.ArrayList; import java.util.Arrays; import java.util.Random; public class RotationForestSimple{ double mean[] ; //the mean of each axis for normalization double deviation[];// the standard deviation of each axis for normalization Treenode tree[] ;//the trees in the forest //builds a rotation forest from the given input/output pairs //output should be <=0 for negative cases and >0 for positive cases //seed is the seed for the random number generator //trees is the number of trees in the forest //minpoints is the minimum data points required for a split to be considered at a leaf //max depth is the maximum allowed depth of an leaves public RotationForestSimple(double input[][], double output[], int seed, int trees, int minpoints, int maxdepth){ //calculate mean and standard deviation along each input variable mean = new double[input[0].length]; deviation = new double[mean.length]; double sum[] = new double[mean.length]; double sum2[] = new double[mean.length] ; for(int k=0;k treedata = new ArrayList(datapermodel) ; //bootstrap aggregating of training data for (int j = 0; j < datapermodel; j++){ //add a random data point (with replacement) to the training data for this tree int nj = Math.abs(rand.nextInt())%input.length; Datapoint d = new Datapoint(copy(input[nj]), output[nj] > 0) ; treedata.add(d) ; } //create this tree tree[k] = new Treenode(treedata, minpoints, rand, maxdepth) ; } } //apply the forest to a new input //returns probability of positive case public double apply(double[] input) { input = copy(input);//copy so as not to alter original input normalize(input, mean, deviation); double output = 0 ; for(int k=0;k splitvalue Treenode upper, lower;//child nodes public int totalpositive, totalnegative;//amount of each class at this node //This data is used only during the training process double[][] 
axis ;//each axis over which splitting is considered Datapoint[][] data ; //the data at this node sorted by each axis //This constructor should only be called for the root node of a new tree as it performs all of the sorting. public Treenode(ArrayList traindata, int minpoints, Random rand, int maxdepth){ //make the axes int axes = traindata.get(0).input.length ; axis = new double[axes][] ; for (int k = 0; k 1 && split(minpoints)){//attempt split data = null;//if succeeded attempt split on children lower.recursivesplit(minpoints, maxdepth - 1); upper.recursivesplit(minpoints, maxdepth - 1); } //this node never attempts to split more than once and can clear the training data data = null; } //splits this node if it should and returns whether it did public boolean split(int minpoints){ //if already split or one class or not enough points remaining then don't split if (branchnode || totalpositive == 0 || totalnegative == 0 || totalpositive + totalnegative < minpoints){ return false; }else{ int bestaxis = -1, splitafter=-1; double bestscore = Double.MAX_VALUE;//any valid split will beat no split int bestLp=0, bestLn=0; for (int k = 0; k < data.length; k++){//try each axis int Lp = 0, Ln = 0, Rp = totalpositive, Rn = totalnegative;//reset the +/- counts for (int j = 0; j < data[k].length - 1; j++){//walk through the data points if (data[k][j].output){ Lp++;//update positive counts Rp--; }else{ Ln++;//update negative counts Rn--; } //score by a parabola approximating information gain double score = Lp * Ln / (double)(Lp + Ln) + Rp * Rn / (double)(Rp + Rn); if (score < bestscore){ // lower score is better bestscore = score;//save score bestaxis = k;//save axis splitafter = j;//svale split location bestLp = Lp;//save positives and negatives to left of split bestLn = Ln ;//so they don't need to be counted again later } } } //if we got a valid split if (bestscore < Double.MAX_VALUE){ splitaxis = axis[bestaxis]; //split halfway between the 2 points around the split splitvalue = 
0.5 * (data[bestaxis][splitafter].dot(splitaxis) + data[bestaxis][splitafter + 1].dot(splitaxis)); Datapoint[][] lowerdata = new Datapoint[axis.length][] ; Datapoint[][] upperdata = new Datapoint[axis.length][] ; lowerdata[bestaxis] = new Datapoint[splitafter+1] ;//initialize the child data arrays for the split axis upperdata[bestaxis] = new Datapoint[data[bestaxis].length-splitafter-1] ; for (int k = 0; k <= splitafter; k++){//for the lower node data points data[bestaxis][k].setchild(false);//mark which leaf it goes to lowerdata[bestaxis][k] = data[bestaxis][k];//go ahead and separate the split axis } for (int k = splitafter+1; k < data[bestaxis].length; k++){ data[bestaxis][k].setchild(true);//mark which leaf it goes on upperdata[bestaxis][k-splitafter-1] = data[bestaxis][k] ;//go ahead and separate the split axis } //separate all the other axes maintaining sorting for (int k = 0; k < axis.length; k++){ if (k != bestaxis){//we already did bestaxis=k above //initialize the arrays lowerdata[k] = new Datapoint[splitafter + 1]; upperdata[k] = new Datapoint[data[bestaxis].length - splitafter - 1]; //fill the data into these arrays without changing order int lowerindex=0,upperindex=0; for (int j = 0; j < data[k].length; j++){ if (data[k][j].upperchild){//if goes in upper node upperdata[k][upperindex] = data[k][j]; upperindex++;//put in upper node data array }else{//if goes in lower node lowerdata[k][lowerindex] = data[k][j]; lowerindex++;//put in lower node array } } } } //initialize but do not yet split the children lower = new Treenode(lowerdata, axis, bestLp, bestLn); upper = new Treenode(upperdata, axis, totalpositive - bestLp, totalnegative - bestLn); branchnode = true; return true; }else{//if no valid splits found return false ;//return did not split } } } //returns a value in the range of 0 to 1 roughly approximating chance of being positive public double apply(double[] x){ if (!branchnode){ return totalpositive / (double)(totalpositive + totalnegative); }else 
if (dot(x, splitaxis) < splitvalue){ return lower.apply(x); }else{ return upper.apply(x); } } }//ends embedded Treenode class //Datapoint class holds an input and an output together which can be sorted, //held as shallow copies, and marked for various things private class Datapoint implements Comparable{ public double[] input; public boolean output; public boolean upperchild; public double sortvalue; //initialize a Datapoint public Datapoint(double[] input, boolean output){ this.input = input; this.output = output; } //dot product of this data point's input with the given axis vector public double dot(double[] axis){ double dot = 0; for (int k = 0; k < axis.length; k++){ dot += input[k] * axis[k]; } return dot; } //set which child this data point should go to public void setchild(boolean upper){ upperchild = upper; } //set value to be sorted on public void setsortvalue(double val){ sortvalue = val ; } //makes data points sortable in java public int compareTo(Object obj){ if (sortvalue < ((Datapoint)obj).sortvalue){ return -1; }else if (sortvalue > ((Datapoint)obj).sortvalue){ return 1; }else{ return 0; } } }//ends embedded data point class }

jakl commented Oct 16, 2013

 LGTM. Just wondering how the seed is used — it allows reproducible output from the same inputs, right? And is the same seed used for all the rand calls?

Alrecenk commented Oct 16, 2013

 That is correct. A "Random" object is created with the seed around line 48 in the forest initialization, and that's used for generating all the random numbers in the actual trees. Pseudo-randomness makes debugging much easier.