{{ message }}

Instantly share code, notes, and snippets.

Alrecenk/RotationForestSimple.java

Last active Dec 25, 2015
An optimized rotation forest algorithm for binary classification on fixed length feature vectors.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
 /*A rotation forest algorithm for binary classification with fixed length feature vectors. *created by Alrecenk for inductivebias.com Oct 2013 */ import java.util.ArrayList; import java.util.Arrays; import java.util.Random; public class RotationForestSimple{ double mean[] ; //the mean of each axis for normalization double deviation[];// the standard deviation of each axis for normalization Treenode tree[] ;//the trees in the forest //builds a rotation forest from the given input/output pairs //output should be <=0 for negative cases and >0 for positive cases //seed is the seed for the random number generator //trees is the number of trees in the forest //minpoints is the minimum data points required for a split to be considered at a leaf //max depth is the maximum allowed depth of an leaves public RotationForestSimple(double input[][], double output[], int seed, int trees, int minpoints, int maxdepth){ //calculate mean and standard deviation along each input variable mean = new double[input[0].length]; deviation = new double[mean.length]; double sum[] = new double[mean.length]; double sum2[] = new double[mean.length] ; for(int k=0;k treedata = new ArrayList(datapermodel) ; //bootstrap aggregating of training data for (int j = 0; j < datapermodel; j++){ //add a random data point (with replacement) to the training data for this tree int nj = Math.abs(rand.nextInt())%input.length; Datapoint d = new Datapoint(copy(input[nj]), output[nj] > 0) ; treedata.add(d) ; } //create this tree tree[k] = new Treenode(treedata, minpoints, rand, maxdepth) ; } } //apply the forest to a new input //returns probability of positive case public double apply(double[] input) { input = copy(input);//copy so as not to alter original input normalize(input, mean, deviation); double output = 0 ; for(int k=0;k splitvalue Treenode upper, lower;//child nodes public int totalpositive, totalnegative;//amount of each class at this node //This data is used only during the training process double[][] 
axis ;//each axis over which splitting is considered Datapoint[][] data ; //the data at this node sorted by each axis //This constructor should only be called for the root node of a new tree as it performs all of the sorting. public Treenode(ArrayList traindata, int minpoints, Random rand, int maxdepth){ //make the axes int axes = traindata.get(0).input.length ; axis = new double[axes][] ; for (int k = 0; k 1 && split(minpoints)){//attempt split data = null;//if succeeded attempt split on children lower.recursivesplit(minpoints, maxdepth - 1); upper.recursivesplit(minpoints, maxdepth - 1); } //this node never attempts to split more than once and can clear the training data data = null; } //splits this node if it should and returns whether it did public boolean split(int minpoints){ //if already split or one class or not enough points remaining then don't split if (branchnode || totalpositive == 0 || totalnegative == 0 || totalpositive + totalnegative < minpoints){ return false; }else{ int bestaxis = -1, splitafter=-1; double bestscore = Double.MAX_VALUE;//any valid split will beat no split int bestLp=0, bestLn=0; for (int k = 0; k < data.length; k++){//try each axis int Lp = 0, Ln = 0, Rp = totalpositive, Rn = totalnegative;//reset the +/- counts for (int j = 0; j < data[k].length - 1; j++){//walk through the data points if (data[k][j].output){ Lp++;//update positive counts Rp--; }else{ Ln++;//update negative counts Rn--; } //score by a parabola approximating information gain double score = Lp * Ln / (double)(Lp + Ln) + Rp * Rn / (double)(Rp + Rn); if (score < bestscore){ // lower score is better bestscore = score;//save score bestaxis = k;//save axis splitafter = j;//svale split location bestLp = Lp;//save positives and negatives to left of split bestLn = Ln ;//so they don't need to be counted again later } } } //if we got a valid split if (bestscore < Double.MAX_VALUE){ splitaxis = axis[bestaxis]; //split halfway between the 2 points around the split splitvalue = 
0.5 * (data[bestaxis][splitafter].dot(splitaxis) + data[bestaxis][splitafter + 1].dot(splitaxis)); Datapoint[][] lowerdata = new Datapoint[axis.length][] ; Datapoint[][] upperdata = new Datapoint[axis.length][] ; lowerdata[bestaxis] = new Datapoint[splitafter+1] ;//initialize the child data arrays for the split axis upperdata[bestaxis] = new Datapoint[data[bestaxis].length-splitafter-1] ; for (int k = 0; k <= splitafter; k++){//for the lower node data points data[bestaxis][k].setchild(false);//mark which leaf it goes to lowerdata[bestaxis][k] = data[bestaxis][k];//go ahead and separate the split axis } for (int k = splitafter+1; k < data[bestaxis].length; k++){ data[bestaxis][k].setchild(true);//mark which leaf it goes on upperdata[bestaxis][k-splitafter-1] = data[bestaxis][k] ;//go ahead and separate the split axis } //separate all the other axes maintaining sorting for (int k = 0; k < axis.length; k++){ if (k != bestaxis){//we already did bestaxis=k above //initialize the arrays lowerdata[k] = new Datapoint[splitafter + 1]; upperdata[k] = new Datapoint[data[bestaxis].length - splitafter - 1]; //fill the data into these arrays without changing order int lowerindex=0,upperindex=0; for (int j = 0; j < data[k].length; j++){ if (data[k][j].upperchild){//if goes in upper node upperdata[k][upperindex] = data[k][j]; upperindex++;//put in upper node data array }else{//if goes in lower node lowerdata[k][lowerindex] = data[k][j]; lowerindex++;//put in lower node array } } } } //initialize but do not yet split the children lower = new Treenode(lowerdata, axis, bestLp, bestLn); upper = new Treenode(upperdata, axis, totalpositive - bestLp, totalnegative - bestLn); branchnode = true; return true; }else{//if no valid splits found return false ;//return did not split } } } //returns a value in the range of 0 to 1 roughly approximating chance of being positive public double apply(double[] x){ if (!branchnode){ return totalpositive / (double)(totalpositive + totalnegative); }else 
if (dot(x, splitaxis) < splitvalue){ return lower.apply(x); }else{ return upper.apply(x); } } }//ends embedded Treenode class //Datapoint class holds an input and an output together which can be sorted, //held as shallow copies, and marked for various things private class Datapoint implements Comparable{ public double[] input; public boolean output; public boolean upperchild; public double sortvalue; //initialize a Datapoint public Datapoint(double[] input, boolean output){ this.input = input; this.output = output; } //dot product of this data point's input with the given axis vector public double dot(double[] axis){ double dot = 0; for (int k = 0; k < axis.length; k++){ dot += input[k] * axis[k]; } return dot; } //set which child this data point should go to public void setchild(boolean upper){ upperchild = upper; } //set value to be sorted on public void setsortvalue(double val){ sortvalue = val ; } //makes data points sortable in java public int compareTo(Object obj){ if (sortvalue < ((Datapoint)obj).sortvalue){ return -1; }else if (sortvalue > ((Datapoint)obj).sortvalue){ return 1; }else{ return 0; } } }//ends embedded data point class }

jakl commented Oct 16, 2013

 LGTM. Just wondering how the seed is used — it allows reproducible output from the same inputs, right? And is the same seed used for all the rand calls?

Alrecenk commented Oct 16, 2013

 That is correct. A "Random" object is created with the seed around line 48 in the forest initialization, and that's used for generating all the random numbers in the actual trees. Pseudo-randomness makes debugging much easier.