Skip to content

Instantly share code, notes, and snippets.

@FlorianCassayre
Created September 7, 2017 14:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save FlorianCassayre/b9ebaaf2c400b4aee18f09aac47c8736 to your computer and use it in GitHub Desktop.
Save FlorianCassayre/b9ebaaf2c400b4aee18f09aac47c8736 to your computer and use it in GitHub Desktop.
// ---- Tunable constants ----
// Number of random clusters generated by initialize() (0 = start with an empty canvas).
final int INITIAL_K = 0;
// Number of points that addCluster() scatters around each cluster centre.
final int POINTS_PER_CLUSTER = 100;
// Scale of the Gaussian scatter in addCluster(); initialize() also uses it as a margin
// so generated cluster centres stay away from the window edges.
final float POINTS_DEVIATION = 50f;
// Minimum delay, in milliseconds, between two k-means iterations triggered by draw().
final int INTERVAL = 100;

// ---- Mutable sketch state ----
// Current number of clusters / centroids (addCluster() increments it).
int k = INITIAL_K;
// Whether draw() also renders the centroids (toggled from keyPressed()).
boolean showCentroids = false;
// millis() timestamp of the last iteration run from draw().
long lastT = millis();
// All data points, and the current centroid positions (mutated in place by step2()).
final ArrayList<Point> points = new ArrayList<Point>();
final ArrayList<Point> centroids = new ArrayList<Point>();
// Cluster assignment built by step1(): data point -> the centroid object it belongs to.
// Point does not override equals/hashCode, so keys and values are compared by identity.
final HashMap<Point, Point> map = new HashMap<Point, Point>();
/**
 * Processing entry point: configures the window, builds the initial data set and
 * performs a first k-means pass before the animation loop starts.
 */
void setup()
{
  // Window / rendering configuration. In Processing 3, size() must be the first call.
  size(1000, 1000);
  surface.setTitle("Algorithme k-means");
  // Hue, saturation and brightness all range over [0, 1]; draw() relies on this.
  colorMode(HSB, 1.0, 1.0, 1.0);
  frameRate(60);

  // Data: generate clusters and seed centroids, run one full iteration, then
  // reassign points so the assignment matches the centroids that step2() just moved.
  initialize();
  iterate();
  step1();

  // Render once immediately, before the regular animation loop takes over.
  draw();
}
/**
 * Resets the sketch: removes all points, centroids and assignments, generates
 * INITIAL_K random clusters, seeds the centroids and assigns every point to one.
 */
void initialize()
{
  // addCluster() increments k once per generated cluster, so start from zero; the loop
  // below then brings k to exactly INITIAL_K. Starting from INITIAL_K would double it.
  k = 0;
  points.clear();
  centroids.clear();
  map.clear();
  for(int i = 0; i < INITIAL_K; i++)
  {
    // Keep cluster centres at least POINTS_DEVIATION away from every window edge.
    final Point cluster = new Point(random(width - 2 * POINTS_DEVIATION) + POINTS_DEVIATION, random(height - 2 * POINTS_DEVIATION) + POINTS_DEVIATION);
    addCluster(cluster);
  }
  randomCentroids();
  // draw() only colours points whose map entry is a current centroid; assign now so the
  // next frame is not blank while waiting for the timed iterate().
  step1();
}
/**
 * Replaces the centroids with k copies of randomly chosen data points.
 *
 * Seeds are drawn without replacement while unused points remain. Two centroids
 * starting on the same point tie on every distance; step1()'s strict '<' tie-break
 * then gives all of those points to the first one, leaving the other with an empty
 * cluster. Repeats are only allowed once every point has been used (k > points.size()).
 */
void randomCentroids()
{
  centroids.clear();
  final int n = points.size();
  // No points yet (e.g. empty canvas with k == 0): there is nothing to seed from.
  if(n == 0)
    return;

  final boolean[] taken = new boolean[n];
  int unused = n;
  while(centroids.size() < k)
  {
    final int index = floor(random(n));
    // Point already used as a seed: draw again while some point is still unused.
    if(taken[index] && unused > 0)
      continue;
    if(!taken[index])
    {
      taken[index] = true;
      unused--;
    }
    // Copy the coordinates: step2() moves centroids in place and must not move the data point.
    final Point rnd = points.get(index);
    centroids.add(new Point(rnd.x, rnd.y));
  }
}
/**
 * Adds one cluster: increments k and scatters POINTS_PER_CLUSTER new data points
 * around {@code point}, with independent Gaussian offsets of scale POINTS_DEVIATION
 * on each axis.
 *
 * @param point centre of the new cluster
 */
void addCluster(Point point)
{
  k++;
  int remaining = POINTS_PER_CLUSTER;
  while(remaining-- > 0)
  {
    // x offset is drawn before y offset for every sample.
    final float dx = randomGaussian() * POINTS_DEVIATION;
    final float dy = randomGaussian() * POINTS_DEVIATION;
    points.add(point.relative(dx, dy));
  }
}
/**
 * Keyboard shortcuts, matching the help text drawn by draw():
 * C clears everything, V restarts k-means from new random centroids,
 * B shows or hides the centroids. Letters are accepted in either case.
 */
void keyPressed()
{
  // The help text shows upper-case letters, so also accept Shift / Caps Lock input.
  // Non-character keys (key == CODED) are left unchanged by toLowerCase and ignored.
  final char pressed = Character.toLowerCase(key);
  if(pressed == 'c')
    initialize();
  else if(pressed == 'v')
  {
    // Reuse the shared seeding routine instead of duplicating it, then reassign
    // immediately so the new clustering is visible on the next frame.
    randomCentroids();
    step1();
  }
  else if(pressed == 'b')
    showCentroids = !showCentroids;
}
/**
 * Mouse click: adds a new cluster of points centred on the cursor (which also
 * increments k) and restarts k-means from freshly seeded centroids.
 */
void mousePressed()
{
  addCluster(new Point(mouseX, mouseY));
  randomCentroids();
  // The old assignments point at centroids that were just discarded, and draw() only
  // colours points mapped to a current centroid. Reassign now, as the 'v' key does,
  // so the canvas does not go blank until the next timed iterate().
  step1();
}
/**
 * Per-frame rendering: draws every point in its cluster's colour, optionally the
 * centroids, then the help text. Runs one k-means iteration once at least INTERVAL
 * milliseconds have passed since the previous one.
 */
void draw()
{
  background(0.0, 0.0, 0.0);

  // Points, grouped by cluster. Cluster c gets hue c/k (HSB mode, set in setup()).
  noStroke();
  for(int c = 0; c < k; c++)
  {
    final Point owner = centroids.get(c);
    fill(c / (float) k, 1.0, 1.0);
    for(Point p : points)
    {
      // Identity comparison: the map stores the centroid object itself.
      if(map.get(p) != owner)
        continue;
      ellipse(p.x, p.y, 5, 5);
    }
  }

  // Centroids: larger discs in the cluster hue with a white outline.
  if(showCentroids)
  {
    strokeWeight(2.0);
    stroke(0.0, 0.0, 1.0);
    for(int c = 0; c < k; c++)
    {
      final Point owner = centroids.get(c);
      fill(c / (float) k, 1.0, 1.0);
      ellipse(owner.x, owner.y, 20, 20);
    }
  }

  // Help text in white, one line every 15 units starting at y = 5.
  fill(0.0, 0.0, 1.0);
  final String[] help = {
    "C : Tout effacer",
    "V : Relancer l'algorithme k-means",
    "B : Afficher/Cacher les centres",
    "Clic souris : Ajouter un nuage de points"
  };
  for(int line = 0; line < help.length; line++)
    text(help[line], 10, 5 + 15 * line, 500, 20);

  // Throttle the animation: at most one k-means iteration per INTERVAL ms.
  if(millis() - lastT < INTERVAL)
    return;
  lastT = millis();
  iterate();
}
/**
 * k-means assignment step: rebuilds {@code map} so each data point maps to its
 * nearest centroid by squared distance. On ties the centroid that comes first in
 * {@code centroids} wins. With no centroids at all, points are mapped to null.
 */
void step1()
{
  map.clear();
  for(Point p : points)
  {
    Point best = null;
    float bestDist = 0; // only read once best != null
    for(Point c : centroids)
    {
      final float d = p.distanceSq(c);
      // Keep the current best unless this centroid is strictly closer.
      if(best != null && !(d < bestDist))
        continue;
      best = c;
      bestDist = d;
    }
    map.put(p, best);
  }
}
/**
 * k-means update step: moves each centroid, in place, to the mean position of the
 * data points that {@code map} assigns to it. A centroid with no assigned points is
 * left where it is.
 */
void step2()
{
  for(Point center : centroids)
  {
    float mx = 0, my = 0;
    int n = 0;
    for(Point point : points)
    {
      // Identity comparison: map values are the centroid objects themselves.
      if(map.get(point) == center)
      {
        mx += point.x;
        my += point.y;
        n++;
      }
    }
    // Without this guard an empty cluster gets 0/0 = NaN coordinates. Every later
    // distance to a NaN centroid is NaN, which breaks the nearest-centroid comparison.
    if(n == 0)
      continue;
    center.x = mx / n;
    center.y = my / n;
  }
}
/**
 * Runs one k-means (Lloyd) iteration: step1() assigns every point to its nearest
 * centroid, then step2() moves each centroid to the mean of its assigned points.
 * The order matters: step2() reads the assignment that step1() just built.
 */
void iterate()
{
step1();
step2();
}
/**
 * Mutable 2D point, used both for data points and for centroids (step2() moves
 * centroids in place). It does not override equals/hashCode, so instances are
 * compared and hashed by identity.
 */
class Point
{
  private float x, y;

  public Point(float x, float y)
  {
    this.x = x;
    this.y = y;
  }

  /** Returns a new point offset from this one by (rx, ry); this point is unchanged. */
  public Point relative(float rx, float ry)
  {
    final float nx = x + rx;
    final float ny = y + ry;
    return new Point(nx, ny);
  }

  /** Squared Euclidean distance to {@code point}; avoids the square root for comparisons. */
  public float distanceSq(Point point)
  {
    final float dx = x - point.x;
    final float dy = y - point.y;
    return sq(dx) + sq(dy);
  }

  /** Euclidean distance to {@code point}. */
  public float distance(Point point)
  {
    final float squared = distanceSq(point);
    return sqrt(squared);
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment