Skip to content

Instantly share code, notes, and snippets.

@sanaulla123
Created May 10, 2012 17:31
Show Gist options
  • Save sanaulla123/2654600 to your computer and use it in GitHub Desktop.
Save sanaulla123/2654600 to your computer and use it in GitHub Desktop.
Nearest Neighbour Classifier on Iris dataset
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
 * 1-nearest-neighbour (1-NN) classifier evaluated on the Iris data set.
 *
 * <p>The loaded data set is split purely by position into three per-class lists
 * (rows 0-49, 50-99 and 100-149). The first {@link #TRAINING_PER_CLASS} samples of
 * each list are used for training and the remaining samples for testing. Each
 * test sample receives the label of its closest training sample as measured by
 * {@code Iris.distance}, and the fraction of correctly labelled test samples is
 * reported.
 *
 * <p>NOTE(review): the positional split assumes the data reader returns the
 * classes in contiguous blocks of {@link #CLASS_SIZE} rows in the order
 * setosa, versicolor, virginica — confirm against {@code DataReader}.
 *
 * @author mohamed
 */
public class NNAlgorithm {

    /** Number of samples in each positional class block of the data set. */
    private static final int CLASS_SIZE = 50;

    /** Number of classes (and therefore class blocks) in the data set. */
    private static final int CLASS_COUNT = 3;

    /** Leading samples of each class used for training; the rest are test samples. */
    private static final int TRAINING_PER_CLASS = 30;

    List<Iris> trainingData;
    List<Iris> testData;
    List<Iris> irisDataSet;
    List<Iris> setosaList;
    List<Iris> versicolorList;
    List<Iris> virginicaList;

    /** Creates a classifier with empty training, test and per-class lists. */
    public NNAlgorithm() {
        trainingData = new ArrayList<Iris>();
        testData = new ArrayList<Iris>();
        setosaList = new ArrayList<Iris>();
        versicolorList = new ArrayList<Iris>();
        virginicaList = new ArrayList<Iris>();
    }

    /**
     * Splits {@link #irisDataSet} into the three per-class lists by position.
     *
     * <p>The per-class lists are {@code subList} views backed by
     * {@code irisDataSet}, not copies.
     *
     * @throws IllegalStateException if the data set is missing or holds fewer
     *     than {@code CLASS_COUNT * CLASS_SIZE} samples
     */
    private void segregateData() {
        int required = CLASS_COUNT * CLASS_SIZE;
        if (irisDataSet == null || irisDataSet.size() < required) {
            throw new IllegalStateException("Expected at least " + required
                    + " Iris samples but got "
                    + (irisDataSet == null ? "none" : String.valueOf(irisDataSet.size())));
        }
        setosaList = irisDataSet.subList(0, CLASS_SIZE);
        versicolorList = irisDataSet.subList(CLASS_SIZE, 2 * CLASS_SIZE);
        virginicaList = irisDataSet.subList(2 * CLASS_SIZE, 3 * CLASS_SIZE);
    }

    /** Appends every sample after the first {@code TRAINING_PER_CLASS} of each class to the test set. */
    private void prepareTestData() {
        testData.addAll(setosaList.subList(TRAINING_PER_CLASS, setosaList.size()));
        testData.addAll(versicolorList.subList(TRAINING_PER_CLASS, versicolorList.size()));
        testData.addAll(virginicaList.subList(TRAINING_PER_CLASS, virginicaList.size()));
    }

    /** Appends the first {@code TRAINING_PER_CLASS} samples of each class to the training set. */
    private void prepareTrainingData() {
        trainingData.addAll(setosaList.subList(0, TRAINING_PER_CLASS));
        trainingData.addAll(versicolorList.subList(0, TRAINING_PER_CLASS));
        trainingData.addAll(virginicaList.subList(0, TRAINING_PER_CLASS));
    }

    /**
     * Classifies every test sample with 1-NN and returns the fraction classified correctly.
     *
     * <p>Also prints the number of correct classifications and the accuracy to stdout.
     *
     * @return accuracy in the range [0, 1]
     * @throws IllegalStateException if the training or test set is empty
     */
    private double calculateNNClassificationAccuracy() {
        if (trainingData.isEmpty()) {
            throw new IllegalStateException("Training data is empty; call prepareTrainingData() first");
        }
        if (testData.isEmpty()) {
            throw new IllegalStateException("Test data is empty; call prepareTestData() first");
        }
        int correctClassified = 0;
        for (Iris test : testData) {
            if (test.type == classify(test)) {
                correctClassified++;
            }
        }
        System.out.println(correctClassified);
        double accuracy = (double) correctClassified / testData.size();
        System.out.println(accuracy);
        return accuracy;
    }

    /**
     * Returns the label of the training sample nearest to {@code sample}.
     *
     * <p>State is local to each call, so no label can leak from a previous test
     * sample. The first training sample is always accepted as a candidate, so a
     * label is produced even when every distance is huge. Ties keep the earliest
     * training sample because the comparison is strict.
     *
     * @param sample the sample to classify
     * @return the type of the closest training sample; requires non-empty training data
     */
    private IrisType classify(Iris sample) {
        Iris nearest = null;
        double minimumDistance = Double.POSITIVE_INFINITY;
        for (Iris training : trainingData) {
            double dist = training.distance(sample);
            if (nearest == null || dist < minimumDistance) {
                nearest = training;
                minimumDistance = dist;
            }
        }
        return nearest.type;
    }

    /**
     * Loads the Iris data, builds the train/test split and prints the 1-NN accuracy.
     *
     * @param args unused
     * @throws IOException if the data reader fails to load the data set
     */
    public static void main(String[] args) throws IOException {
        DataReader reader = new DataReader();
        NNAlgorithm algorithm = new NNAlgorithm();
        algorithm.irisDataSet = reader.getIrisData();
        algorithm.segregateData();
        algorithm.prepareTestData();
        algorithm.prepareTrainingData();
        double accuracy = algorithm.calculateNNClassificationAccuracy();
        System.out.println("Accuracy: " + accuracy + " or " + (accuracy * 100) + "%");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment