Created
July 22, 2014 21:20
-
-
Save leonaburime/69015d84a14fd328495a to your computer and use it in GitHub Desktop.
K-means clustering algorithm implemented in Java; the source imports the Apache Commons Lang and Math libraries.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package KMeans; | |
import org.apache.commons.lang3.ArrayUtils; | |
import org.apache.commons.math3.linear.MatrixUtils; | |
import org.apache.commons.math3.linear.RealMatrix; | |
import org.apache.commons.math3.linear.RealVector; | |
import org.apache.commons.math3.stat.StatUtils; | |
import java.lang.reflect.Array; | |
import java.util.*; | |
/**
 * K-means clustering demo on a small fruit data set.
 *
 * <p>{@link #main} groups the rows of {@link #input} into {@link #k} clusters with Lloyd's
 * algorithm, labels every cluster by majority vote over {@link #output}, prints the clusters and
 * then prints which cluster {@link #predict} is nearest to.
 *
 * <p>Created by LASFE using IntelliJ on 7/20/2014.
 */
public class KMeans {
    // Test set comes from http://www.jiaaro.com/KNN-for-humans/
    /*
     * Color codes used in the second column of input:
     * red    1
     * orange 2
     * yellow 3
     * green  4
     * blue   5
     * purple 6
     */
    // The results of the program can be found at http://www.jiaaro.com/KNN-for-humans/

    /** Data points, one row per fruit: {weight, color, number of seeds}. */
    public double[][] input = {
        // weight, color, # seeds
        {303, 3, 1},
        {370, 1, 2},
        {298, 3, 1},
        {277, 3, 1},
        {377, 4, 2},
        {299, 3, 1},
        {382, 1, 2},
        {374, 4, 6},
        {303, 4, 1},
        {309, 3, 1},
        {359, 1, 2},
        {366, 1, 4},
        {311, 3, 1},
        {302, 3, 1},
        {373, 4, 4},
        {305, 3, 1}
    };

    /**
     * Known type of each row of {@link #input}, matched by index.
     *
     * <p>NOTE(review): this array has one more entry than {@code input} has rows; the trailing
     * label is never looked up.
     */
    public String[] output = {
        // type
        "banana",
        "apple",
        "banana",
        "banana",
        "apple",
        "banana",
        "apple",
        "apple",
        "banana",
        "banana",
        "apple",
        "apple",
        "banana",
        "banana",
        "apple",
        "banana",
        "apple",
    };

    /** Label of each cluster, indexed like the centroids; filled in after clustering. */
    public String[] label;

    /** Point to classify once clustering has finished. */
    public double[] predict = {371, 3, 6}; // Let's see if it's a banana or an apple

    /**
     * One zero-filled feature row per cluster, created by {@link #main}.
     *
     * <p>NOTE(review): nothing in this class reads this field; clustering keeps its own map.
     */
    public HashMap<Integer, double[]> clusters = new HashMap<Integer, double[]>();

    /** Number of clusters; in our case one per fruit type ('apple' or 'banana'). */
    public int k = 2;

    /** Maximum number of Lloyd iterations (at least one iteration always runs). */
    public int max_iterations = 1000;

    /**
     * Clusters the built-in data set, prints each cluster, then prints which cluster
     * {@link #predict} is closest to.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        KMeans km = new KMeans();
        // Give every cluster a zero-filled row with one slot per feature.
        for (int i = 0; i < km.k; i++) {
            km.clusters.put(i, new double[km.input[0].length]);
        }
        // NOTE(review): features are not scaled, so weight (hundreds) dominates the distance
        // over color and seed count (single digits).
        double[][] centroids = km.solve();
        km.predictClass(centroids);
    }

    /** Prints the label of the centroid nearest to {@link #predict}. */
    private void predictClass(double[][] centroids) {
        int index = euclideanDistance(this.predict, centroids);
        System.out.println(Arrays.toString(this.predict) + " is closest to Centroid " + this.label[index]);
    }

    /**
     * Runs Lloyd's algorithm on {@link #input} until the centroids stop changing or
     * {@link #max_iterations} iterations have run, then labels and prints the clusters.
     *
     * @return the final centroids, one row per cluster
     */
    private double[][] solve() {
        double[][] centroids = createRandomCentroids(this.k, this.input);
        double[][] oldCentroids;
        HashMap<Integer, ArrayList<double[]>> clusters;
        int iterations = 0;
        // do-while guarantees at least one assignment pass, so clusters always has k entries.
        do {
            oldCentroids = centroids;
            clusters = findClosestCentroids(this.input, centroids);
            centroids = getNewCentroids(clusters, oldCentroids);
            iterations++;
        } while (!converged(oldCentroids, centroids, iterations));

        // Find out which label belongs to which centroid.
        assignLabels(clusters);
        // Print which cluster each data point ended up in.
        for (int i = 0; i < clusters.size(); i++) {
            System.out.println("\nCluster " + i + "(" + this.label[i] + "): "
                    + Arrays.deepToString(clusters.get(i).toArray(new double[][] {})));
        }
        return centroids;
    }

    /**
     * Labels every cluster with the most common {@link #output} value among its points.
     * Ties go to the label encountered first; an empty cluster gets a {@code null} label.
     *
     * @param clusters cluster index to the rows of {@link #input} assigned to it
     */
    private void assignLabels(HashMap<Integer, ArrayList<double[]>> clusters) {
        // Clusters hold the very row arrays of input, so an identity lookup recovers each row's index.
        Map<double[], Integer> rowIndex = new IdentityHashMap<double[], Integer>();
        for (int i = 0; i < this.input.length; i++) {
            rowIndex.put(this.input[i], i);
        }
        this.label = new String[clusters.size()];
        for (int c = 0; c < clusters.size(); c++) {
            // LinkedHashMap keeps first-seen order, which makes the tie-break deterministic.
            Map<String, Integer> votes = new LinkedHashMap<String, Integer>();
            for (double[] point : clusters.get(c)) {
                Integer row = rowIndex.get(point);
                if (row == null || row >= this.output.length) {
                    continue; // no known label for this point
                }
                String name = this.output[row];
                Integer count = votes.get(name);
                votes.put(name, count == null ? 1 : count + 1);
            }
            String winner = null;
            int winnerVotes = 0;
            for (Map.Entry<String, Integer> entry : votes.entrySet()) {
                if (entry.getValue() > winnerVotes) {
                    winner = entry.getKey();
                    winnerVotes = entry.getValue();
                }
            }
            this.label[c] = winner;
        }
    }

    /**
     * Computes each cluster's new centroid as the per-feature mean of its points.
     *
     * @param hash     cluster index (0..k-1) to the points assigned to it
     * @param previous the centroids used for this assignment; reused for clusters with no points
     * @return the new centroids, one row per cluster
     */
    private double[][] getNewCentroids(HashMap<Integer, ArrayList<double[]>> hash, double[][] previous) {
        double[][] newCentroids = new double[hash.size()][];
        for (int i = 0; i < hash.size(); i++) {
            List<double[]> members = hash.get(i);
            if (members.isEmpty()) {
                // The mean of no points is undefined; keep the old centroid instead of crashing.
                newCentroids[i] = previous[i].clone();
                continue;
            }
            double[] mean = new double[members.get(0).length];
            for (double[] point : members) {
                for (int j = 0; j < mean.length; j++) {
                    mean[j] += point[j];
                }
            }
            for (int j = 0; j < mean.length; j++) {
                mean[j] /= members.size();
            }
            newCentroids[i] = mean;
        }
        return newCentroids;
    }

    /**
     * Decides whether to stop iterating.
     *
     * @return {@code true} when the centroids did not change, or when {@code iterations} has
     *     reached {@link #max_iterations}
     */
    private boolean converged(double[][] oldCentroids, double[][] centroids, int iterations) {
        // Exact comparison is sound: once assignments stop changing, the means are recomputed
        // from the same points in the same order and come out bit-for-bit identical.
        if (Arrays.deepEquals(oldCentroids, centroids)) {
            System.out.println("Centroids have converged. Returning...");
            return true;
        }
        // Don't iterate forever: stop after exactly max_iterations passes.
        if (iterations >= this.max_iterations) {
            System.out.println("Max iterations reached. Returning...");
            return true;
        }
        return false;
    }

    /**
     * Picks {@code row} distinct data points at random as the initial centroids.
     * Seeding from the data means every initial centroid owns at least the point it was copied
     * from, unlike random points in the bounding box, which may attract no points at all.
     *
     * @param row   number of centroids (k)
     * @param input data points
     * @return copies of {@code row} randomly chosen rows of {@code input}
     * @throws IllegalArgumentException if {@code row} is not between 1 and {@code input.length}
     */
    private double[][] createRandomCentroids(int row, double[][] input) {
        if (row < 1 || row > input.length) {
            throw new IllegalArgumentException(
                    "k must be between 1 and the number of data points (" + input.length + "), got " + row);
        }
        List<Integer> order = new ArrayList<Integer>(input.length);
        for (int i = 0; i < input.length; i++) {
            order.add(i);
        }
        Collections.shuffle(order, new Random());
        double[][] centroids = new double[row][];
        for (int i = 0; i < row; i++) {
            // Copy so centroids never share storage with the data rows.
            centroids[i] = input[order.get(i)].clone();
        }
        return centroids;
    }

    /**
     * Groups every input point under the index of its nearest centroid.
     *
     * @return cluster index (0..centroids.length-1) to the input rows assigned to it; every
     *     index is present, possibly with an empty list
     */
    private HashMap<Integer, ArrayList<double[]>> findClosestCentroids(double[][] input, double[][] centroids) {
        HashMap<Integer, ArrayList<double[]>> clusters = new HashMap<Integer, ArrayList<double[]>>();
        for (int i = 0; i < centroids.length; i++) {
            clusters.put(i, new ArrayList<double[]>());
        }
        for (double[] point : input) {
            clusters.get(euclideanDistance(point, centroids)).add(point);
        }
        return clusters;
    }

    /**
     * Returns the index of the centroid with the smallest Euclidean distance to {@code input}.
     * Despite the name, this returns an index, not a distance. On a tie the lowest index wins;
     * NaN distances never win.
     *
     * @param input     the point to locate
     * @param centroids candidate centroids, each with {@code input.length} coordinates
     * @return index into {@code centroids} of the nearest one
     * @throws IllegalArgumentException if {@code centroids} is empty or a centroid's length
     *     differs from {@code input.length}
     */
    public static int euclideanDistance(double[] input, double[][] centroids) {
        if (centroids.length == 0) {
            throw new IllegalArgumentException("centroids must not be empty");
        }
        int nearest = 0;
        double nearestSq = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double[] centroid = centroids[i];
            if (centroid.length != input.length) {
                throw new IllegalArgumentException("centroid " + i + " has " + centroid.length
                        + " coordinates, expected " + input.length);
            }
            // Squared distance orders candidates exactly like the distance itself; skip the sqrt.
            double sumSq = 0;
            for (int j = 0; j < input.length; j++) {
                double d = centroid[j] - input[j];
                sumSq += d * d;
            }
            if (sumSq < nearestSq) {
                nearestSq = sumSq;
                nearest = i;
            }
        }
        return nearest;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment