Skip to content

Instantly share code, notes, and snippets.

@loicknuchel
Created November 21, 2013 23:35
Show Gist options
  • Save loicknuchel/7591918 to your computer and use it in GitHub Desktop.
Save loicknuchel/7591918 to your computer and use it in GitHub Desktop.
Writing some code in Scala and Java about collections.
5.1 3.5 1.4 0.2 Iris-setosa
4.9 3.0 1.4 0.2 Iris-setosa
4.7 3.2 1.3 0.2 Iris-setosa
4.6 3.1 1.5 0.2 Iris-setosa
5.0 3.6 1.4 0.2 Iris-setosa
5.4 3.9 1.7 0.4 Iris-setosa
4.6 3.4 1.4 0.3 Iris-setosa
5.0 3.4 1.5 0.2 Iris-setosa
4.4 2.9 1.4 0.2 Iris-setosa
4.9 3.1 1.5 0.1 Iris-setosa
5.4 3.7 1.5 0.2 Iris-setosa
4.8 3.4 1.6 0.2 Iris-setosa
4.8 3.0 1.4 0.1 Iris-setosa
4.3 3.0 1.1 0.1 Iris-setosa
5.8 4.0 1.2 0.2 Iris-setosa
5.7 4.4 1.5 0.4 Iris-setosa
5.4 3.9 1.3 0.4 Iris-setosa
5.1 3.5 1.4 0.3 Iris-setosa
5.7 3.8 1.7 0.3 Iris-setosa
5.1 3.8 1.5 0.3 Iris-setosa
5.4 3.4 1.7 0.2 Iris-setosa
5.1 3.7 1.5 0.4 Iris-setosa
4.6 3.6 1.0 0.2 Iris-setosa
5.1 3.3 1.7 0.5 Iris-setosa
4.8 3.4 1.9 0.2 Iris-setosa
5.0 3.0 1.6 0.2 Iris-setosa
5.0 3.4 1.6 0.4 Iris-setosa
5.2 3.5 1.5 0.2 Iris-setosa
5.2 3.4 1.4 0.2 Iris-setosa
4.7 3.2 1.6 0.2 Iris-setosa
4.8 3.1 1.6 0.2 Iris-setosa
5.4 3.4 1.5 0.4 Iris-setosa
5.2 4.1 1.5 0.1 Iris-setosa
5.5 4.2 1.4 0.2 Iris-setosa
4.9 3.1 1.5 0.1 Iris-setosa
5.0 3.2 1.2 0.2 Iris-setosa
5.5 3.5 1.3 0.2 Iris-setosa
4.9 3.1 1.5 0.1 Iris-setosa
4.4 3.0 1.3 0.2 Iris-setosa
5.1 3.4 1.5 0.2 Iris-setosa
5.0 3.5 1.3 0.3 Iris-setosa
4.5 2.3 1.3 0.3 Iris-setosa
4.4 3.2 1.3 0.2 Iris-setosa
5.0 3.5 1.6 0.6 Iris-setosa
5.1 3.8 1.9 0.4 Iris-setosa
4.8 3.0 1.4 0.3 Iris-setosa
5.1 3.8 1.6 0.2 Iris-setosa
4.6 3.2 1.4 0.2 Iris-setosa
5.3 3.7 1.5 0.2 Iris-setosa
5.0 3.3 1.4 0.2 Iris-setosa
7.0 3.2 4.7 1.4 Iris-versicolor
6.4 3.2 4.5 1.5 Iris-versicolor
6.9 3.1 4.9 1.5 Iris-versicolor
5.5 2.3 4.0 1.3 Iris-versicolor
6.5 2.8 4.6 1.5 Iris-versicolor
5.7 2.8 4.5 1.3 Iris-versicolor
6.3 3.3 4.7 1.6 Iris-versicolor
4.9 2.4 3.3 1.0 Iris-versicolor
6.6 2.9 4.6 1.3 Iris-versicolor
5.2 2.7 3.9 1.4 Iris-versicolor
5.0 2.0 3.5 1.0 Iris-versicolor
5.9 3.0 4.2 1.5 Iris-versicolor
6.0 2.2 4.0 1.0 Iris-versicolor
6.1 2.9 4.7 1.4 Iris-versicolor
5.6 2.9 3.6 1.3 Iris-versicolor
6.7 3.1 4.4 1.4 Iris-versicolor
5.6 3.0 4.5 1.5 Iris-versicolor
5.8 2.7 4.1 1.0 Iris-versicolor
6.2 2.2 4.5 1.5 Iris-versicolor
5.6 2.5 3.9 1.1 Iris-versicolor
5.9 3.2 4.8 1.8 Iris-versicolor
6.1 2.8 4.0 1.3 Iris-versicolor
6.3 2.5 4.9 1.5 Iris-versicolor
6.1 2.8 4.7 1.2 Iris-versicolor
6.4 2.9 4.3 1.3 Iris-versicolor
6.6 3.0 4.4 1.4 Iris-versicolor
6.8 2.8 4.8 1.4 Iris-versicolor
6.7 3.0 5.0 1.7 Iris-versicolor
6.0 2.9 4.5 1.5 Iris-versicolor
5.7 2.6 3.5 1.0 Iris-versicolor
5.5 2.4 3.8 1.1 Iris-versicolor
5.5 2.4 3.7 1.0 Iris-versicolor
5.8 2.7 3.9 1.2 Iris-versicolor
6.0 2.7 5.1 1.6 Iris-versicolor
5.4 3.0 4.5 1.5 Iris-versicolor
6.0 3.4 4.5 1.6 Iris-versicolor
6.7 3.1 4.7 1.5 Iris-versicolor
6.3 2.3 4.4 1.3 Iris-versicolor
5.6 3.0 4.1 1.3 Iris-versicolor
5.5 2.5 4.0 1.3 Iris-versicolor
5.5 2.6 4.4 1.2 Iris-versicolor
6.1 3.0 4.6 1.4 Iris-versicolor
5.8 2.6 4.0 1.2 Iris-versicolor
5.0 2.3 3.3 1.0 Iris-versicolor
5.6 2.7 4.2 1.3 Iris-versicolor
5.7 3.0 4.2 1.2 Iris-versicolor
5.7 2.9 4.2 1.3 Iris-versicolor
6.2 2.9 4.3 1.3 Iris-versicolor
5.1 2.5 3.0 1.1 Iris-versicolor
5.7 2.8 4.1 1.3 Iris-versicolor
6.3 3.3 6.0 2.5 Iris-virginica
5.8 2.7 5.1 1.9 Iris-virginica
7.1 3.0 5.9 2.1 Iris-virginica
6.3 2.9 5.6 1.8 Iris-virginica
6.5 3.0 5.8 2.2 Iris-virginica
7.6 3.0 6.6 2.1 Iris-virginica
4.9 2.5 4.5 1.7 Iris-virginica
7.3 2.9 6.3 1.8 Iris-virginica
6.7 2.5 5.8 1.8 Iris-virginica
7.2 3.6 6.1 2.5 Iris-virginica
6.5 3.2 5.1 2.0 Iris-virginica
6.4 2.7 5.3 1.9 Iris-virginica
6.8 3.0 5.5 2.1 Iris-virginica
5.7 2.5 5.0 2.0 Iris-virginica
5.8 2.8 5.1 2.4 Iris-virginica
6.4 3.2 5.3 2.3 Iris-virginica
6.5 3.0 5.5 1.8 Iris-virginica
7.7 3.8 6.7 2.2 Iris-virginica
7.7 2.6 6.9 2.3 Iris-virginica
6.0 2.2 5.0 1.5 Iris-virginica
6.9 3.2 5.7 2.3 Iris-virginica
5.6 2.8 4.9 2.0 Iris-virginica
7.7 2.8 6.7 2.0 Iris-virginica
6.3 2.7 4.9 1.8 Iris-virginica
6.7 3.3 5.7 2.1 Iris-virginica
7.2 3.2 6.0 1.8 Iris-virginica
6.2 2.8 4.8 1.8 Iris-virginica
6.1 3.0 4.9 1.8 Iris-virginica
6.4 2.8 5.6 2.1 Iris-virginica
7.2 3.0 5.8 1.6 Iris-virginica
7.4 2.8 6.1 1.9 Iris-virginica
7.9 3.8 6.4 2.0 Iris-virginica
6.4 2.8 5.6 2.2 Iris-virginica
6.3 2.8 5.1 1.5 Iris-virginica
6.1 2.6 5.6 1.4 Iris-virginica
7.7 3.0 6.1 2.3 Iris-virginica
6.3 3.4 5.6 2.4 Iris-virginica
6.4 3.1 5.5 1.8 Iris-virginica
6.0 3.0 4.8 1.8 Iris-virginica
6.9 3.1 5.4 2.1 Iris-virginica
6.7 3.1 5.6 2.4 Iris-virginica
6.9 3.1 5.1 2.3 Iris-virginica
5.8 2.7 5.1 1.9 Iris-virginica
6.8 3.2 5.9 2.3 Iris-virginica
6.7 3.3 5.7 2.5 Iris-virginica
6.7 3.0 5.2 2.3 Iris-virginica
6.3 2.5 5.0 1.9 Iris-virginica
6.5 3.0 5.2 2.0 Iris-virginica
6.2 3.4 5.4 2.3 Iris-virginica
5.9 3.0 5.1 1.8 Iris-virginica
package org.knuchel.playground;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
// K-nearest neighbor
/*
0.0 < sepalLength < 7.9
0.0 < sepalWidth < 4.4
0.0 < petalLength < 6.9
0.0 < petalWidth < 2.5
*/
/**
 * K-nearest-neighbour classifier for iris samples.
 *
 * A sample is described by four measurements (sepal length/width, petal
 * length/width). Its species is predicted by a majority vote among the
 * {@code k} samples of a labelled dataset that are closest to it in
 * Euclidean distance.
 */
public class KNNJava {

    /**
     * Loads the dataset from {@code data/iris.data.csv}, predicts the species of one
     * hard-coded sample with k=5 and prints the result.
     */
    public static void main(String[] args) {
        List<Iris> dataset = loadDataset("data/iris.data.csv");
        Integer k = 5;
        Double sepalLength = 5.7d, sepalWidth = 2.6d, petalLength = 3.5d, petalWidth = 1d;
        String predictedSpecie = predictSpecie(dataset, k, sepalLength, sepalWidth, petalLength, petalWidth);
        System.out.println("Iris with [sepalLength=" + sepalLength + ", sepalWidth=" + sepalWidth + ", petalLength=" + petalLength + ", petalWidth="
                + petalWidth + "] should be a " + predictedSpecie);
        // try different values of k and see how prediction errors change
        // evaluate(dataset);
    }

    /**
     * Predicts the species of a sample from its k nearest neighbours in {@code dataset}.
     *
     * @param dataset labelled samples to search
     * @param k       number of neighbours that vote; values below 1 are treated as 1 and
     *                values above the dataset size are capped to the dataset size
     * @return the species with the most votes among the k nearest samples; when several
     *         species are tied, the one owning the closest neighbour wins; {@code null}
     *         when {@code dataset} is empty
     */
    public static String predictSpecie(List<Iris> dataset, Integer k, Double sepalLength, Double sepalWidth, Double petalLength, Double petalWidth) {
        // calculate distance for each sample in dataset
        Iris unknownIris = new Iris(sepalLength, sepalWidth, petalLength, petalWidth, null);
        List<Score> scores = new ArrayList<Score>();
        for (Iris iris : dataset) {
            scores.add(new Score(unknownIris.distance(iris), iris.specie));
        }
        // stable sort: samples at the same distance keep their dataset order
        Collections.sort(scores, Score.COMPARATOR);

        // count occurrences among the K nearest neighbours
        int neighbours = Math.min(Math.max(k, 1), scores.size());
        Map<String, Integer> occurrenceCount = new HashMap<String, Integer>();
        for (int i = 0; i < neighbours; i++) {
            String specie = scores.get(i).specie;
            Integer seen = occurrenceCount.get(specie);
            occurrenceCount.put(specie, seen == null ? 1 : seen + 1);
        }

        // Find the most frequent species. The neighbours are walked from nearest to
        // farthest (instead of iterating the HashMap) so that a tie is resolved in
        // favour of the closest neighbour rather than by hash order.
        String mostFrequentSpecie = null;
        int bestCount = 0;
        for (int i = 0; i < neighbours; i++) {
            String specie = scores.get(i).specie;
            int votes = occurrenceCount.get(specie);
            if (votes > bestCount) {
                bestCount = votes;
                mostFrequentSpecie = specie;
            }
        }
        return mostFrequentSpecie;
    }

    /**
     * Reads a comma-separated file with one sample per line: four numeric columns
     * followed by the species name. Empty lines are skipped.
     *
     * I/O errors are printed to stderr and whatever was read so far (possibly an
     * empty list) is returned.
     *
     * @param csvFile path of the file to read
     * @return the samples parsed from the file
     */
    public static List<Iris> loadDataset(String csvFile) {
        List<Iris> dataset = new ArrayList<Iris>();
        BufferedReader br = null;
        String line = "";
        String cvsSplitBy = ",";
        try {
            br = new BufferedReader(new FileReader(csvFile));
            while ((line = br.readLine()) != null) {
                if (line.length() > 0) {
                    String[] cell = line.split(cvsSplitBy);
                    dataset.add(new Iris(Double.parseDouble(cell[0]), Double.parseDouble(cell[1]), Double.parseDouble(cell[2]), Double.parseDouble(cell[3]),
                            cell[4]));
                }
            }
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            // always release the file handle, even when reading failed
            if (br != null) {
                try {
                    br.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
        return dataset;
    }

    /** Distance of one labelled sample to the sample being classified. */
    static class Score {
        /** Orders scores by ascending distance. */
        public static final Comparator<Score> COMPARATOR = new Comparator<Score>() {
            @Override
            public int compare(Score o1, Score o2) {
                return o1.score.compareTo(o2.score);
            }
        };
        public Double score;
        public String specie;

        public Score(Double score, String specie) {
            this.score = score;
            this.specie = specie;
        }
    }

    /** One iris sample: four measurements and its species label (null when unknown). */
    static class Iris {
        public Double sepalLength;
        public Double sepalWidth;
        public Double petalLength;
        public Double petalWidth;
        public String specie;

        public Iris(Double sepalLength, Double sepalWidth, Double petalLength, Double petalWidth, String specie) {
            this.sepalLength = sepalLength;
            this.sepalWidth = sepalWidth;
            this.petalLength = petalLength;
            this.petalWidth = petalWidth;
            this.specie = specie;
        }

        /** Euclidean distance between the four measurements of this sample and {@code that}. */
        public Double distance(Iris that) {
            return Math.sqrt(Math.pow(sepalLength - that.sepalLength, 2) + Math.pow(sepalWidth - that.sepalWidth, 2)
                    + Math.pow(petalLength - that.petalLength, 2) + Math.pow(petalWidth - that.petalWidth, 2));
        }

        @Override
        public String toString() {
            return "Iris [specie=" + specie + ", sepalLength=" + sepalLength + ", sepalWidth=" + sepalWidth + ", petalLength=" + petalLength + ", petalWidth="
                    + petalWidth + "]";
        }
    }

    /**
     * Prints, for every k from 1 to 20, how many test samples are misclassified.
     *
     * Each of the three known species is shuffled and cut in two: the first half is
     * used as learning data, the second half as test data. Samples with any other
     * species label are ignored.
     */
    public static void evaluate(List<Iris> dataset) {
        // split dataset in 2 parts : one part to learn, the other part to test
        List<Iris> versicolor = new ArrayList<Iris>(), virginica = new ArrayList<Iris>(), setosa = new ArrayList<Iris>();
        for (Iris iris : dataset) {
            // literal first: a sample without a label is ignored instead of throwing
            if ("Iris-versicolor".equals(iris.specie))
                versicolor.add(iris);
            else if ("Iris-virginica".equals(iris.specie))
                virginica.add(iris);
            else if ("Iris-setosa".equals(iris.specie))
                setosa.add(iris);
        }
        Collections.shuffle(versicolor);
        Collections.shuffle(virginica);
        Collections.shuffle(setosa);
        List<Iris> learningData = new ArrayList<Iris>(), testData = new ArrayList<Iris>();
        splitInHalf(versicolor, learningData, testData);
        splitInHalf(virginica, learningData, testData);
        splitInHalf(setosa, learningData, testData);
        // for each value of k, count the number of errors
        for (int k = 1; k <= 20; k++) {
            int errors = 0;
            for (Iris iris : testData) {
                String predicted = predictSpecie(learningData, k, iris.sepalLength, iris.sepalWidth, iris.petalLength, iris.petalWidth);
                // iris.specie is never null here; predicted is null when learningData is empty
                if (!iris.specie.equals(predicted)) {
                    errors++;
                }
            }
            System.out.println(errors + " errors on " + testData.size() + " tests with k=" + k);
        }
    }

    /** Appends the first half of {@code group} to {@code learningData} and the remainder to {@code testData}. */
    private static void splitInHalf(List<Iris> group, List<Iris> learningData, List<Iris> testData) {
        int half = group.size() / 2;
        learningData.addAll(group.subList(0, half));
        testData.addAll(group.subList(half, group.size()));
    }
}
package org.knuchel.playground
import scala.io.Source
import scala.util.Random
// K-nearest neighbor
/*
0.0 < sepalLength < 7.9
0.0 < sepalWidth < 4.4
0.0 < petalLength < 6.9
0.0 < petalWidth < 2.5
*/
/**
 * K-nearest-neighbour classifier for iris samples.
 *
 * A sample is described by four measurements (sepal length/width, petal
 * length/width). Its species is predicted by a majority vote among the `k`
 * samples of a labelled dataset that are closest to it in Euclidean distance.
 */
object KNNScala {

  /**
   * Loads the dataset from `data/iris.data.csv`, predicts the species of one
   * hard-coded sample with k=5 and prints the result.
   */
  def main(args: Array[String]): Unit = {
    val dataset = loadDataset("data/iris.data.csv")
    val k = 5
    val features = (5.7, 2.6, 3.5, 1d)
    val predictedSpecie = predictSpecie(dataset, k, features)
    println("Iris with [sepalLength=" + features._1 + ", sepalWidth=" + features._2 + ", petalLength=" + features._3 +
      ", petalWidth=" + features._4 + "] should be a " + predictedSpecie)
    // try different values of k and see how prediction errors change
    // evaluate(dataset)
  }

  /**
   * Predicts the species of a sample from its k nearest neighbours in `dataset`.
   *
   * @param dataset  labelled samples to search
   * @param k        number of neighbours that vote
   * @param features (sepalLength, sepalWidth, petalLength, petalWidth) of the sample to classify
   * @return the species with the most votes among the k nearest samples; when several
   *         species are tied, the one owning the closest neighbour wins
   * @throws NoSuchElementException when there is no neighbour to vote (empty dataset or k < 1)
   */
  def predictSpecie(dataset: List[Iris], k: Int, features: (Double, Double, Double, Double)): String = {
    // species of the K nearest samples, ordered from nearest to farthest
    val nearestSpecies = dataset
      .map(iris => (iris.distance(features), iris.getSpecie))
      .sorted
      .take(k)
      .map(_._2)

    // `distinct` keeps first-occurrence (nearest-first) order and `sortBy` is stable,
    // so species with the same number of votes stay ordered by their closest neighbour
    // instead of depending on the iteration order of a Map.
    nearestSpecies.distinct
      .map(specie => (specie, nearestSpecies.count(_ == specie)))
      .sortBy { case (_, votes) => -votes }
      .head._1
  }

  /**
   * Reads a comma-separated file with one sample per line: four numeric columns
   * followed by the species name. Empty lines are skipped.
   *
   * @param csvFile path of the file to read
   * @return the samples parsed from the file
   */
  def loadDataset(csvFile: String): List[Iris] = {
    val file = Source.fromFile(csvFile)
    try {
      // toList forces the lazy line iterator before the file is closed
      file.getLines().filter(line => line.length() > 0).map { line =>
        val cell = line.split(",")
        new Iris(cell(0).toDouble, cell(1).toDouble, cell(2).toDouble, cell(3).toDouble, cell(4))
      }.toList
    } finally {
      // release the file handle even when a line fails to parse
      file.close()
    }
  }

  /** One iris sample: four measurements and its species label. */
  class Iris(sepalLength: Double, sepalWidth: Double, petalLength: Double, petalWidth: Double, specie: String) {
    def getSpecie: String = specie

    def getFeatures: (Double, Double, Double, Double) = (sepalLength, sepalWidth, petalLength, petalWidth)

    /** Euclidean distance between this sample and the measurements in `that`. */
    def distance(that: (Double, Double, Double, Double)): Double =
      Math.sqrt(
        Math.pow(this.sepalLength - that._1, 2) + Math.pow(this.sepalWidth - that._2, 2) +
          Math.pow(this.petalLength - that._3, 2) + Math.pow(this.petalWidth - that._4, 2))

    // override def toString = "Iris"
    override def toString: String =
      "Iris [specie=" + specie + ", sepalLength=" + sepalLength + ", sepalWidth=" + sepalWidth +
        ", petalLength=" + petalLength + ", petalWidth=" + petalWidth + "]"
  }

  /**
   * Prints, for every k from 1 to 20, how many test samples are misclassified.
   *
   * Each of the three known species is shuffled and cut in two: the first half is
   * used as learning data, the second half as test data. Samples with any other
   * species label are ignored.
   */
  def evaluate(dataset: List[Iris]): Unit = {
    // split dataset in 2 parts : one part to learn, the other part to test
    val halves = List("Iris-versicolor", "Iris-virginica", "Iris-setosa").map { name =>
      val shuffled = Random.shuffle(dataset.filter(iris => iris.getSpecie == name))
      shuffled.splitAt(shuffled.length / 2)
    }
    val learningData = halves.flatMap(_._1)
    val testData = halves.flatMap(_._2)
    // test errors for different values of k
    for (k <- 1 to 20) {
      val errors = testData.count(iris => iris.getSpecie != predictSpecie(learningData, k, iris.getFeatures))
      println(errors.toString + " errors on " + testData.length + " tests with k=" + k)
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment