Skip to content

Instantly share code, notes, and snippets.

@Alrecenk
Last active December 25, 2015 17:49
Show Gist options
  • Save Alrecenk/7016330 to your computer and use it in GitHub Desktop.
Save Alrecenk/7016330 to your computer and use it in GitHub Desktop.
An optimized rotation forest algorithm for binary classification on fixed length feature vectors.
/*A rotation forest algorithm for binary classification with fixed length feature vectors.
*created by Alrecenk for inductivebias.com Oct 2013
*/
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
public class RotationForestSimple{
    double mean[]; //per-axis mean of the training inputs, used for normalization
    double deviation[]; //per-axis standard deviation of the training inputs
    Treenode tree[]; //the trees in the forest

    //Builds a rotation forest from the given input/output pairs.
    //input: fixed-length feature vectors (left unmodified by training)
    //output: <=0 for negative cases, >0 for positive cases
    //seed: seeds the random number generator so training is reproducible
    //trees: the number of trees in the forest
    //minpoints: the minimum data points required for a split to be considered at a node
    //maxdepth: the maximum allowed depth of any leaf
    public RotationForestSimple(double input[][], double output[], int seed, int trees, int minpoints, int maxdepth){
        //calculate mean and standard deviation along each input axis in a single pass
        mean = new double[input[0].length];
        deviation = new double[mean.length];
        double sum[] = new double[mean.length];
        double sum2[] = new double[mean.length];
        for(int k=0;k<input.length;k++){
            for(int j=0;j<mean.length;j++){
                double val = input[k][j];
                sum[j] += val;
                sum2[j] += val*val;
            }
        }
        for(int j=0;j<mean.length;j++){
            sum[j] /= input.length;
            sum2[j] /= input.length;
            double dev = Math.sqrt(sum2[j] - (sum[j]*sum[j]));
            //a constant axis has zero deviation; map it to 1 so normalize never divides by zero
            deviation[j] = dev > 0 ? dev : 1;
            mean[j] = sum[j];
        }
        //normalize copies of the data points to mean zero and deviation one
        //(copies, so the caller's input array is left untouched)
        double normalized[][] = new double[input.length][];
        for(int k=0;k<input.length;k++){
            normalized[k] = copy(input[k]);
            normalize(normalized[k], mean, deviation);
        }
        //build the forest
        tree = new Treenode[trees];
        int datapermodel = input.length;
        Random rand = new Random(seed); //seeded random number generator shared by all trees
        for(int k=0;k<trees;k++){
            //bootstrap aggregating: sample datapermodel points with replacement for each tree
            ArrayList<Datapoint> treedata = new ArrayList<Datapoint>(datapermodel);
            for(int j=0;j<datapermodel;j++){
                //nextInt(bound) is uniform on [0,bound); the previous Math.abs(nextInt())%n
                //form produced a negative index when nextInt() returned Integer.MIN_VALUE
                int nj = rand.nextInt(input.length);
                treedata.add(new Datapoint(copy(normalized[nj]), output[nj] > 0));
            }
            //create this tree
            tree[k] = new Treenode(treedata, minpoints, rand, maxdepth);
        }
    }

    //Applies the forest to a new input vector.
    //Returns the estimated probability of the positive class (average over trees, in [0,1]).
    public double apply(double[] input){
        input = copy(input); //copy so the caller's array is not normalized in place
        normalize(input, mean, deviation);
        double output = 0;
        for(int k=0;k<tree.length;k++){
            output += tree[k].apply(input);
        }
        return output/tree.length;
    }

    //Returns a normally distributed random vector of the given dimension using the
    //Box-Muller transform, so the normalized vector is uniform over directions.
    public static double[] normaldistribution(int dim, Random rand){
        double[] axis = new double[dim];
        for(int k=0;k<dim;k++){
            axis[k] = Math.sqrt(-2*Math.log(rand.nextDouble()))*Math.sin(2*Math.PI*rand.nextDouble());
        }
        return axis;
    }

    //Scales a vector in place to unit length.
    public static void normalize(double a[]){
        double scale = 0;
        for(int k=0;k<a.length;k++){
            scale += a[k]*a[k];
        }
        scale = 1/Math.sqrt(scale);
        for(int k=0;k<a.length;k++){
            a[k] *= scale;
        }
    }

    //Shifts and scales a point in place so data with the given mean and deviation
    //ends up with mean zero and deviation one.
    public static void normalize(double a[], double mean[], double deviation[]){
        for(int k=0;k<a.length;k++){
            a[k] = (a[k]-mean[k])/deviation[k];
        }
    }

    //Dot product of two vectors.
    public static double dot(double[] a, double[] b){
        double c = 0;
        for(int k=0;k<a.length;k++){
            c += a[k]*b[k];
        }
        return c;
    }

    //Copies a vector to a new array.
    public static double[] copy(double[] b){
        return Arrays.copyOf(b, b.length);
    }

    //Returns a new vector equal to a*s.
    public static double[] scale(double[] a, double s){
        double[] b = new double[a.length];
        for(int k=0;k<a.length;k++){
            b[k] = a[k]*s;
        }
        return b;
    }

    //Returns a new vector equal to the difference a-b.
    public static double[] subtract(double[] a, double[] b){
        double[] c = new double[a.length];
        for(int k=0;k<a.length;k++){
            c[k] = a[k]-b[k];
        }
        return c;
    }

    //A single node (branch or leaf) of a rotation tree.
    //Static nested class: it never touches enclosing-instance state, so it should not
    //carry the hidden outer-instance reference a non-static inner class would hold.
    private static class Treenode{
        //This data makes up the final tree nodes
        boolean branchnode = false; //whether this node has split into children
        double[] splitaxis; //the axis this node split on if a branch
        double splitvalue; //branch test: x dot splitaxis >= splitvalue goes to the upper child
        Treenode upper, lower; //child nodes, null on leaves
        public int totalpositive, totalnegative; //class counts of the training data at this node
        //This data is used only during training and is released afterwards
        double[][] axis; //each candidate splitting axis
        Datapoint[][] data; //the node's data, sorted separately along each axis

        //Root-node constructor: generates the random axes and performs all of the sorting.
        public Treenode(ArrayList<Datapoint> traindata, int minpoints, Random rand, int maxdepth){
            //make one random axis per input dimension
            int axes = traindata.get(0).input.length;
            axis = new double[axes][];
            for(int k=0;k<axes;k++){
                //a normalized normally distributed random vector is uniformly
                //distributed on an N-sphere
                axis[k] = normaldistribution(axes, rand);
                //Gram-Schmidt step against up to 5 previous axes to spread directions out
                //(not strictly necessary, but it seems to help a little)
                for(int j=Math.max(0,k-5);j<k;j++){
                    axis[k] = subtract(axis[k], scale(axis[j], dot(axis[k], axis[j])));
                }
                normalize(axis[k]);
            }
            //sort the data once along every axis; children inherit these orders
            data = new Datapoint[axes][];
            for(int k=0;k<axes;k++){
                data[k] = new Datapoint[traindata.size()];
                for(int j=0;j<traindata.size();j++){
                    Datapoint d = traindata.get(j);
                    data[k][j] = d;
                    d.setsortvalue(d.dot(axis[k])); //sort key = projection onto this axis
                }
                Arrays.sort(data[k]); //sort this array along this axis
            }
            //count the total positive and negative instances in the data set
            totalpositive = 0;
            totalnegative = 0;
            for(int j=0;j<traindata.size();j++){
                if(traindata.get(j).output){
                    totalpositive++;
                }else{
                    totalnegative++;
                }
            }
            //recursively split the tree until done
            recursivesplit(minpoints, maxdepth);
        }

        //Internal constructor used during splitting; the Datapoint arrays arrive
        //already sorted along each axis.
        public Treenode(Datapoint[][] data, double[][] axis, int tp, int tn){
            this.data = data;
            this.axis = axis;
            branchnode = false;
            totalpositive = tp;
            totalnegative = tn;
        }

        //Recursively splits this node until maxdepth is reached or leaves are
        //uniform in class or contain fewer than minpoints points.
        public void recursivesplit(int minpoints, int maxdepth){
            if(maxdepth > 1 && split(minpoints)){ //attempt split
                lower.recursivesplit(minpoints, maxdepth-1);
                upper.recursivesplit(minpoints, maxdepth-1);
            }
            //a node only ever attempts one split, so the training data can be released
            data = null;
        }

        //Splits this node if a valid split exists and returns whether it did.
        public boolean split(int minpoints){
            //don't split if already split, uniform in class, or too few points remain
            if(branchnode || totalpositive == 0 || totalnegative == 0 || totalpositive + totalnegative < minpoints){
                return false;
            }
            int bestaxis = -1, splitafter = -1;
            double bestscore = Double.MAX_VALUE; //any valid split beats no split
            int bestLp = 0, bestLn = 0;
            for(int k=0;k<data.length;k++){ //try each axis
                int Lp = 0, Ln = 0, Rp = totalpositive, Rn = totalnegative; //reset left/right class counts
                for(int j=0;j<data[k].length-1;j++){ //walk through the sorted data points
                    if(data[k][j].output){
                        Lp++; //move one positive from right to left
                        Rp--;
                    }else{
                        Ln++; //move one negative from right to left
                        Rn--;
                    }
                    //score by a parabola approximating information gain; lower is better
                    double score = Lp*Ln/(double)(Lp+Ln) + Rp*Rn/(double)(Rp+Rn);
                    if(score < bestscore){
                        bestscore = score; //save score
                        bestaxis = k; //save axis
                        splitafter = j; //save split location
                        bestLp = Lp; //save class counts left of the split
                        bestLn = Ln; //so they don't need recounting later
                    }
                }
            }
            if(bestscore < Double.MAX_VALUE){ //we found a valid split
                splitaxis = axis[bestaxis];
                //split halfway between the two points adjacent to the split
                splitvalue = 0.5*(data[bestaxis][splitafter].dot(splitaxis) + data[bestaxis][splitafter+1].dot(splitaxis));
                Datapoint[][] lowerdata = new Datapoint[axis.length][];
                Datapoint[][] upperdata = new Datapoint[axis.length][];
                //separate the split axis directly, marking which child each point goes to
                lowerdata[bestaxis] = new Datapoint[splitafter+1];
                upperdata[bestaxis] = new Datapoint[data[bestaxis].length-splitafter-1];
                for(int k=0;k<=splitafter;k++){
                    data[bestaxis][k].setchild(false);
                    lowerdata[bestaxis][k] = data[bestaxis][k];
                }
                for(int k=splitafter+1;k<data[bestaxis].length;k++){
                    data[bestaxis][k].setchild(true);
                    upperdata[bestaxis][k-splitafter-1] = data[bestaxis][k];
                }
                //separate every other axis while preserving each axis's sort order
                for(int k=0;k<axis.length;k++){
                    if(k != bestaxis){ //bestaxis was separated above
                        lowerdata[k] = new Datapoint[splitafter+1];
                        upperdata[k] = new Datapoint[data[bestaxis].length-splitafter-1];
                        int lowerindex = 0, upperindex = 0;
                        for(int j=0;j<data[k].length;j++){
                            if(data[k][j].upperchild){ //goes in the upper node
                                upperdata[k][upperindex] = data[k][j];
                                upperindex++;
                            }else{ //goes in the lower node
                                lowerdata[k][lowerindex] = data[k][j];
                                lowerindex++;
                            }
                        }
                    }
                }
                //initialize but do not yet split the children
                lower = new Treenode(lowerdata, axis, bestLp, bestLn);
                upper = new Treenode(upperdata, axis, totalpositive-bestLp, totalnegative-bestLn);
                branchnode = true;
                return true;
            }
            return false; //no valid split found
        }

        //Returns a value in [0,1] roughly approximating the chance of being positive.
        public double apply(double[] x){
            if(!branchnode){
                return totalpositive/(double)(totalpositive+totalnegative);
            }else if(dot(x, splitaxis) < splitvalue){
                return lower.apply(x);
            }else{
                return upper.apply(x);
            }
        }
    } //ends nested Treenode class

    //Holds an input vector and its binary label together; sortable on a per-axis
    //projection value and markable with the child it belongs to during a split.
    private static class Datapoint implements Comparable<Datapoint>{
        public double[] input;
        public boolean output;
        public boolean upperchild; //scratch flag: which child this point goes to in a split
        public double sortvalue; //scratch sort key: projection onto the axis being sorted

        //initialize a Datapoint
        public Datapoint(double[] input, boolean output){
            this.input = input;
            this.output = output;
        }

        //dot product of this data point's input with the given axis vector
        public double dot(double[] axis){
            double d = 0;
            for(int k=0;k<axis.length;k++){
                d += input[k]*axis[k];
            }
            return d;
        }

        //set which child this data point should go to
        public void setchild(boolean upper){
            upperchild = upper;
        }

        //set the value to be sorted on
        public void setsortvalue(double val){
            sortvalue = val;
        }

        //orders data points by sortvalue so Arrays.sort sorts along the current axis
        public int compareTo(Datapoint other){
            return Double.compare(sortvalue, other.sortvalue);
        }
    } //ends nested Datapoint class
}
@Alrecenk
Copy link
Author

That is correct. A "Random" object is created with the seed around line 48 in the forest initialization, and that's used for generating all the random numbers in the actual trees. Pseudo-randomness makes debugging much easier.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment