Created
September 7, 2017 14:04
-
-
Save FlorianCassayre/b9ebaaf2c400b4aee18f09aac47c8736 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// ---- Tunable parameters ----
final int INITIAL_K = 0;             // number of clusters generated by initialize()
final int POINTS_PER_CLUSTER = 100;  // points created by each addCluster() call
final float POINTS_DEVIATION = 50;   // scale applied to the randomGaussian() offsets
final int INTERVAL = 100;            // minimum milliseconds between two iterate() calls in draw()

// ---- Data set and clustering result ----
final ArrayList<Point> points = new ArrayList<Point>();
final ArrayList<Point> centroids = new ArrayList<Point>();
// Maps each point to the centroid it was assigned to by step1().
final HashMap<Point, Point> map = new HashMap<Point, Point>();

// ---- Mutable sketch state ----
int k = INITIAL_K;                   // current number of centroids
boolean showCentroids = false;       // toggled with the 'b' key
long lastT = millis();               // time of the last iterate() triggered from draw()
/**
 * Processing entry point: configures the window, generates the initial
 * data set and runs a first k-means pass.
 */
void setup()
{
  size(1000, 1000);
  surface.setTitle("Algorithme k-means");
  initialize();
  iterate();
  colorMode(HSB, 1.0, 1.0, 1.0);
  frameRate(60);
  // Re-assign the points after iterate() moved the centroids so the first
  // rendered frame shows an up-to-date assignment.
  step1();
  // draw() is intentionally not called here: Processing invokes it
  // automatically after setup() and it should never be called explicitly.
}
/**
 * Resets the sketch: drops all points, centroids and assignments, generates
 * INITIAL_K fresh clusters at random positions and picks new random centroids.
 */
void initialize()
{
  // addCluster() increments k once per cluster, so start counting from 0.
  // Starting from INITIAL_K would leave k at twice the number of clusters.
  k = 0;
  points.clear();
  centroids.clear();
  map.clear();
  for(int i = 0; i < INITIAL_K; i++)
  {
    // Keep cluster centers at least POINTS_DEVIATION away from the window borders.
    final float centerX = random(width - 2 * POINTS_DEVIATION) + POINTS_DEVIATION;
    final float centerY = random(height - 2 * POINTS_DEVIATION) + POINTS_DEVIATION;
    addCluster(new Point(centerX, centerY));
  }
  randomCentroids();
}
/**
 * Replaces the current centroids with k new ones, each placed on a copy of a
 * randomly chosen data point (the same point may be chosen more than once).
 */
void randomCentroids()
{
  centroids.clear();
  int remaining = k;
  while(remaining-- > 0)
  {
    final int index = floor(random(points.size()));
    final Point seed = points.get(index);
    // Copy the coordinates so that moving the centroid never moves the data point.
    centroids.add(new Point(seed.x, seed.y));
  }
}
/**
 * Adds POINTS_PER_CLUSTER points scattered around the given center using
 * Gaussian offsets scaled by POINTS_DEVIATION, and increments k.
 *
 * @param point center of the new cloud of points
 */
void addCluster(Point point)
{
  for(int remaining = POINTS_PER_CLUSTER; remaining > 0; remaining--)
  {
    final float offsetX = randomGaussian() * POINTS_DEVIATION;
    final float offsetY = randomGaussian() * POINTS_DEVIATION;
    points.add(point.relative(offsetX, offsetY));
  }
  k++;
}
/**
 * Keyboard handler.
 *   'c' : clear everything and start over
 *   'v' : restart k-means from new random centroids
 *   'b' : show or hide the centroid markers
 */
void keyPressed()
{
  if(key == 'c')
  {
    initialize();
  }
  else if(key == 'v')
  {
    // Reuse the shared helper instead of duplicating its loop here.
    randomCentroids();
    // Rebuild the assignment immediately so the next frame matches the new centroids.
    step1();
  }
  else if(key == 'b')
  {
    showCentroids = !showCentroids;
  }
}
/**
 * Mouse handler: adds a new cloud of points centered on the cursor and
 * restarts k-means from new random centroids.
 */
void mousePressed()
{
  addCluster(new Point(mouseX, mouseY));
  randomCentroids();
  // randomCentroids() creates new centroid objects, so every entry in map now
  // refers to a discarded centroid and draw() would render no points until the
  // next timed iterate(). Rebuild the assignment right away.
  step1();
}
/**
 * Renders one frame (points, optional centroid markers, help text) and runs
 * one k-means iteration when at least INTERVAL milliseconds have elapsed
 * since the previous one.
 */
void draw()
{
  background(0.0, 0.0, 0.0);
  drawClusterPoints();
  if(showCentroids)
  {
    drawCentroidMarkers();
  }
  drawHelpText();

  if(millis() - lastT >= INTERVAL)
  {
    lastT = millis();
    iterate();
  }
}

// Draws every point as a small disc whose hue identifies its assigned centroid.
private void drawClusterPoints()
{
  noStroke();
  for(int cluster = 0; cluster < k; cluster++)
  {
    final Point owner = centroids.get(cluster);
    fill((float) cluster / k, 1.0, 1.0);
    for(Point p : points)
    {
      // Identity comparison: map stores the centroid objects themselves.
      if(map.get(p) != owner)
      {
        continue;
      }
      ellipse(p.x, p.y, 5, 5);
    }
  }
}

// Draws each centroid as a larger outlined disc using its cluster hue.
private void drawCentroidMarkers()
{
  strokeWeight(2.0);
  stroke(0.0, 0.0, 1.0);
  for(int cluster = 0; cluster < k; cluster++)
  {
    final Point c = centroids.get(cluster);
    fill((float) cluster / k, 1.0, 1.0);
    ellipse(c.x, c.y, 20, 20);
  }
}

// Draws the key/mouse help lines in the top-left corner.
private void drawHelpText()
{
  fill(0.0, 0.0, 1.0);
  text("C : Tout effacer", 10, 5, 500, 20);
  text("V : Relancer l'algorithme k-means", 10, 20, 500, 20);
  text("B : Afficher/Cacher les centres", 10, 35, 500, 20);
  text("Clic souris : Ajouter un nuage de points", 10, 50, 500, 20);
}
/**
 * Assignment step of k-means: rebuilds map so that every point is associated
 * with the centroid at the smallest squared distance. When centroids is
 * empty, each point is mapped to null.
 */
void step1()
{
  map.clear();
  for(Point p : points)
  {
    Point closest = null;
    // The initial value is never compared: the first centroid is always
    // accepted through the null check below.
    float best = Float.NaN;
    for(int i = 0; i < centroids.size(); i++)
    {
      final Point candidate = centroids.get(i);
      final float d = p.distanceSq(candidate);
      // Strict comparison: on a tie the earlier centroid keeps the point.
      if(closest == null || d < best)
      {
        best = d;
        closest = candidate;
      }
    }
    map.put(p, closest);
  }
}
/**
 * Update step of k-means: moves every centroid to the mean position of the
 * points currently assigned to it in map.
 *
 * A centroid with no assigned points is left where it is. Dividing by a zero
 * count would set its coordinates to NaN, after which it could never win a
 * point again or, if first in the list, would capture every point in step1().
 * Empty clusters do occur, for example when randomCentroids() picks the same
 * point twice: step1() then gives all tied points to the earlier centroid.
 */
void step2()
{
  for(Point center : centroids)
  {
    float sumX = 0, sumY = 0;
    int count = 0;
    for(Point point : points)
    {
      // Identity comparison: map stores the centroid objects themselves.
      if(map.get(point) == center)
      {
        sumX += point.x;
        sumY += point.y;
        count++;
      }
    }
    if(count == 0)
    {
      // Empty cluster: keep the previous position instead of producing NaN.
      continue;
    }
    center.x = sumX / count;
    center.y = sumY / count;
  }
}
/**
 * Runs one full k-means iteration: assign every point to its nearest
 * centroid, then move each centroid to the mean of its points.
 */
void iterate()
{
  step1(); // assignment
  step2(); // centroid update
}
/**
 * A mutable 2D position. Instances are compared by identity elsewhere in the
 * sketch, so equals/hashCode are deliberately not overridden.
 */
class Point
{
  private float x, y;

  public Point(float x, float y)
  {
    this.x = x;
    this.y = y;
  }

  /** Returns a new point translated by (rx, ry); this point is not modified. */
  public Point relative(float rx, float ry)
  {
    final float nx = x + rx;
    final float ny = y + ry;
    return new Point(nx, ny);
  }

  /** Returns the squared Euclidean distance between this point and the given one. */
  public float distanceSq(Point point)
  {
    final float dx = x - point.x;
    final float dy = y - point.y;
    return dx * dx + dy * dy;
  }

  /** Returns the Euclidean distance between this point and the given one. */
  public float distance(Point point)
  {
    final float squared = distanceSq(point);
    return sqrt(squared);
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment