Skip to content

Instantly share code, notes, and snippets.

@bmander
Created May 9, 2010 07:06
Show Gist options
  • Save bmander/394995 to your computer and use it in GitHub Desktop.
Save bmander/394995 to your computer and use it in GitHub Desktop.
// Number of groups each KMeansAlgorithm built by spinAlgo()/keyPressed() is asked to form.
int NGROUPS=5;
// Large sentinel used as the starting value for the best-distance and best-score comparisons.
float INFINITY = 100000000;
// How many random points setup() creates when USE_RANDOM_POINTS is true.
int RANDOM_POINTS = 75;
// When true, setup() fills the point list with RANDOM_POINTS random points.
boolean USE_RANDOM_POINTS = true;
// Multiplier applied to the per-group share of points when KMeansAlgorithm computes its group size cap.
float TRUCK_SLACK = 1.1;
/**
 * A group of points together with the mean (centroid) of its members.
 *
 * The mean is kept up to date by add(). Code that edits the points list
 * directly must call refreshMean() afterwards.
 */
class PointGroup {
  // Members of the group (Point instances).
  ArrayList points;
  // Centroid of the members; null while the group is empty.
  PVector mean;
  // Running component-wise sum of the members, so add() can update the mean
  // without walking the whole list again.
  PVector sum;

  PointGroup() {
    points = new ArrayList();
    sum = new PVector(0, 0);
  }

  /** Returns the number of points in the group. */
  int size() {
    return this.points.size();
  }

  /**
   * Recomputes the running sum and the mean from scratch. The mean is set to
   * null for an empty group rather than to the result of a division by zero.
   */
  void refreshMean() {
    PVector total = new PVector(0, 0);
    for (int i = 0; i < this.points.size(); i++) {
      total.add((Point) this.points.get(i));
    }
    this.sum = total;
    if (this.points.isEmpty()) {
      this.mean = null;
    } else {
      this.mean = PVector.div(total, this.points.size());
    }
  }

  /**
   * Adds a point to the group and updates the mean. A null point is ignored
   * so that a caller that ran out of points does not crash the sketch.
   */
  void add(Point pt) {
    if (pt == null) {
      return;
    }
    points.add(pt);
    // Same additions, in the same order, as summing the list from zero.
    sum.add(pt);
    this.mean = PVector.div(sum, this.points.size());
  }

  /**
   * Returns the distance from pt to the group's mean. The group must contain
   * at least one point, otherwise there is no mean to measure against.
   */
  float dist( Point pt ) {
    return PVector.dist( pt, this.mean );
  }

  /** Returns the sum of squared distances from every member to the mean (0 for an empty group). */
  float sumOfSquares() {
    float ret = 0;
    for (int i = 0; i < this.points.size(); i++) {
      Point pt = (Point) this.points.get(i);
      ret += sq(PVector.dist( pt, this.mean ));
    }
    return ret;
  }

  /** Draws a circle at the mean and a line from the mean to each member; draws nothing for an empty group. */
  void draw() {
    if (mean == null) {
      return;
    }
    ellipse( mean.x, mean.y, 10, 10 );
    for (int i = 0; i < this.points.size(); i++) {
      Point pt = (Point) this.points.get(i);
      line( mean.x, mean.y, pt.x, pt.y );
    }
  }
}
/**
 * Randomised, capacity-limited grouping of points.
 *
 * Each group is seeded with one random point. The remaining points are then
 * taken in random order and each is added to the nearest group (by distance
 * to the group's mean) that still has room. score() rates the result so that
 * several random runs can be compared.
 */
class KMeansAlgorithm {
  // Points not yet assigned to a group (a private copy of the caller's list).
  ArrayList points;
  // The PointGroup instances being filled.
  ArrayList groups;
  // Maximum number of points a group may hold before it is skipped.
  int groupSize;

  /** Removes and returns a random unassigned point, or null when none are left. */
  Point removeRandom() {
    if( this.points.size() == 0 ) {
      return null;
    }
    int randIndex = int( random(this.points.size()) );
    return (Point)this.points.remove( randIndex );
  }

  /**
   * Copies the point list, creates nGroups groups and seeds each with one
   * random point for as long as points are available.
   *
   * @param points  the points to group; the list itself is not modified
   * @param nGroups number of groups to create
   */
  KMeansAlgorithm(ArrayList points, int nGroups) {
    int total = points.size();
    if (nGroups > 0) {
      // Original cap: the even share scaled by TRUCK_SLACK and truncated.
      int slackCap = int(TRUCK_SLACK*(total/nGroups));
      // Truncation can leave the groups with less total room than there are
      // points (e.g. 9 points, 5 groups -> cap 1), so never go below the
      // rounded-up even share.
      int evenShare = (total + nGroups - 1) / nGroups;
      this.groupSize = max(slackCap, evenShare);
    } else {
      // No groups requested; avoid dividing by zero.
      this.groupSize = 0;
    }

    this.groups = new ArrayList();
    for(int i=0; i<nGroups; i++) {
      this.groups.add( new PointGroup() );
    }

    this.points = new ArrayList(points);
    for(int i=0; i<nGroups; i++) {
      Point seed = removeRandom();
      if( seed == null ) {
        // Fewer points than groups: the remaining groups stay empty.
        break;
      }
      ((PointGroup)this.groups.get(i)).add( seed );
    }
  }

  /**
   * Assigns one random unassigned point to a group.
   *
   * @return true if a point was assigned, false if there was nothing to
   *         assign or no group exists to receive it
   */
  boolean iterate() {
    Point pt = removeRandom();
    if( pt == null ) {
      return false;
    }
    PointGroup winner = nearestGroup( pt, true );
    if( winner == null ) {
      // Every group is at capacity: overflow into the nearest one instead of failing.
      winner = nearestGroup( pt, false );
    }
    if( winner == null ) {
      // There are no groups at all; put the point back and stop.
      this.points.add( pt );
      return false;
    }
    winner.add( pt );
    return true;
  }

  /**
   * Finds the group whose mean is closest to pt. An empty group has no mean
   * yet, so it is returned immediately to receive the point as its seed.
   * When onlyWithRoom is true, groups already holding groupSize points are skipped.
   * Returns null if no group qualifies.
   */
  private PointGroup nearestGroup( Point pt, boolean onlyWithRoom ) {
    PointGroup best = null;
    float bestDist = 0;
    for(int i=0; i<this.groups.size(); i++) {
      PointGroup pg = (PointGroup)this.groups.get(i);
      if( pg.size() == 0 ) {
        return pg;
      }
      if( onlyWithRoom && pg.size() >= this.groupSize ) {
        continue;
      }
      float d = pg.dist( pt );
      if( best == null || d < bestDist ) {
        bestDist = d;
        best = pg;
      }
    }
    return best;
  }

  /** Calls iterate() until it reports that nothing more can be assigned. */
  void run() {
    while( iterate() ) {
      // all work happens in iterate()
    }
  }

  /** Returns the sum over all groups of each group's sum of squares; lower is tighter. */
  float score() {
    float ret = 0;
    for(int i=0; i<groups.size(); i++) {
      PointGroup pg = (PointGroup)groups.get(i);
      ret += pg.sumOfSquares();
    }
    return ret;
  }

  /** Draws every non-empty group. */
  void draw() {
    for(int i=0; i<groups.size(); i++) {
      PointGroup pg = (PointGroup)groups.get(i);
      // An empty group has no mean to draw from.
      if( pg.size() > 0 ) {
        pg.draw();
      }
    }
  }
}
/** A PVector built from integer coordinates that knows how to draw itself. */
class Point extends PVector {

  /** Creates a point at the given coordinates. */
  Point(int x, int y) {
    super(x, y);
  }

  /** Plots the point at its x/y position using the current stroke settings. */
  void draw() {
    point(x, y);
  }
}
// Every point in the sketch: those created in setup() plus one per mouse press.
ArrayList points;
// The most recently constructed grouping; null until 'k' or 's' is pressed.
KMeansAlgorithm algo = null;
// The lowest-scoring grouping recorded by spinAlgo(); null until one is recorded.
KMeansAlgorithm winnerAlgo = null;
// Score of winnerAlgo; starts at INFINITY so that any smaller score replaces it.
float winnerScore = INFINITY;
/**
 * Configures the 600x600 canvas and creates the point list, filling it with
 * RANDOM_POINTS random points when USE_RANDOM_POINTS is set.
 */
void setup() {
  size(600,600);
  smooth();
  noFill();

  points = new ArrayList();
  if (!USE_RANDOM_POINTS) {
    return;
  }

  int remaining = RANDOM_POINTS;
  while (remaining > 0) {
    // x is drawn before y, matching the order of the random() calls.
    int px = int(random(width));
    int py = int(random(height));
    points.add( new Point( px, py ) );
    remaining--;
  }
}
/**
 * Renders one frame: all points with a thick black stroke, the current
 * grouping in thin black, the best grouping so far in dark red, and then,
 * if a grouping exists, runs another random attempt via spinAlgo().
 */
void draw() {
  background(255);
  stroke(0);

  // Points are drawn with a heavy stroke so they stand out.
  strokeWeight(4);
  for (int i = 0; i < points.size(); i++) {
    ((Point) points.get(i)).draw();
  }

  strokeWeight(1);
  boolean hasAlgo = (algo != null);
  if (hasAlgo) {
    algo.draw();
  }

  // The best grouping is drawn on top in dark red.
  stroke(128,0,0);
  if (winnerAlgo != null) {
    winnerAlgo.draw();
  }

  // Keep searching for a better grouping once one has been started.
  if (hasAlgo) {
    spinAlgo();
  }
}
/**
 * Adds a point at the mouse position and discards the recorded best grouping.
 *
 * The recorded winner was scored on the previous, smaller point set. Its sum
 * of squares is not comparable with groupings of the new set and would
 * usually keep beating them, leaving the new point out of the highlighted
 * result. Resetting it lets the next spinAlgo() start a fresh comparison.
 */
void mousePressed() {
  points.add( new Point( mouseX, mouseY ) );
  winnerAlgo = null;
  winnerScore = INFINITY;
}
/**
 * Builds a fresh grouping of the current points, runs it to completion and
 * records it as the winner if its score is lower than the best seen so far.
 */
void spinAlgo() {
  algo = new KMeansAlgorithm( points, NGROUPS );
  algo.run();

  float score = algo.score();
  boolean improved = score < winnerScore;
  if (!improved) {
    return;
  }
  winnerAlgo = algo;
  winnerScore = score;
}
/**
 * Keyboard controls:
 *   'k' - create a new grouping of the current points (seeded, not yet run)
 *   'i' - perform a single assignment step on the current grouping, if any
 *   's' - run one complete random attempt via spinAlgo()
 */
void keyPressed() {
  switch (key) {
    case 'k':
      algo = new KMeansAlgorithm( points, NGROUPS );
      break;
    case 'i':
      if (algo != null) {
        algo.iterate();
      }
      break;
    case 's':
      spinAlgo();
      break;
    default:
      break;
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment