Skip to content

Instantly share code, notes, and snippets.

@AliZafar120
Created October 24, 2017 12:57
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AliZafar120/baa9448e8d081f138e19a126957ecd34 to your computer and use it in GitHub Desktop.
Save AliZafar120/baa9448e8d081f138e19a126957ecd34 to your computer and use it in GitHub Desktop.
K means clustering image segmentation
package lab8;
import java.awt.BorderLayout;
import java.awt.image.BufferedImage;
import java.io.File;
import java.util.Arrays;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
public class KMeans {
BufferedImage original;
BufferedImage result;
Cluster[] clusters;
public static final int MODE_CONTINUOUS = 1;
public static final int MODE_ITERATIVE = 2;
public static void main(String[] args) {
// parse arguments
String src = "C:\\Users\\User\\Desktop\\Image\\Input\\Lenna.png";
int k = 15;
String m = "-1";
int mode = 1;
if (m.equals("-c")) {
mode = MODE_ITERATIVE;
} else if (m.equals("-c")) {
mode = MODE_CONTINUOUS;
}
// create new KMeans object
KMeans kmeans = new KMeans();
// call the function to actually start the clustering
BufferedImage dstImage = kmeans.calculate(loadImage(src),
k,mode);
// save the resulting image ]
JFrame frame = new JFrame();
JLabel lblimage = new JLabel(new ImageIcon(dstImage));
JPanel mainPanel = new JPanel(new BorderLayout());
mainPanel.add(lblimage);
// add more components here
frame.add(mainPanel);
frame.setVisible(true);
frame.setSize(300, 400);
//saveImage(dst, dstImage);
}
public KMeans() { }
public BufferedImage calculate(BufferedImage image,
int k, int mode) {
long start = System.currentTimeMillis();
int w = image.getWidth();
int h = image.getHeight();
// create clusters
clusters = createClusters(image,k);
// create cluster lookup table
int[] lut = new int[w*h];
Arrays.fill(lut, -1);
// at first loop all pixels will move their clusters
boolean pixelChangedCluster = true;
// loop until all clusters are stable!
int loops = 0;
while (pixelChangedCluster) {
pixelChangedCluster = false;
loops++;
for (int y=0;y<h;y++) {
for (int x=0;x<w;x++) {
int pixel = image.getRGB(x, y);
Cluster cluster = findMinimalCluster(pixel);
if (lut[w*y+x]!=cluster.getId()) {
// cluster changed
if (mode==MODE_CONTINUOUS) {
if (lut[w*y+x]!=-1) {
// remove from possible previous
// cluster
clusters[lut[w*y+x]].removePixel(
pixel);
}
// add pixel to cluster
cluster.addPixel(pixel);
}
// continue looping
pixelChangedCluster = true;
// update lut
lut[w*y+x] = cluster.getId();
}
}
}
if (mode==MODE_ITERATIVE) {
// update clusters
for (int i=0;i<clusters.length;i++) {
clusters[i].clear();
}
for (int y=0;y<h;y++) {
for (int x=0;x<w;x++) {
int clusterId = lut[w*y+x];
// add pixels to cluster
clusters[clusterId].addPixel(
image.getRGB(x, y));
}
}
}
}
// create result image
BufferedImage result = new BufferedImage(w, h,
BufferedImage.TYPE_INT_RGB);
for (int y=0;y<h;y++) {
for (int x=0;x<w;x++) {
int clusterId = lut[w*y+x];
result.setRGB(x, y, clusters[clusterId].getRGB());
}
}
long end = System.currentTimeMillis();
System.out.println("Clustered to "+k
+ " clusters in "+loops
+" loops in "+(end-start)+" ms.");
return result;
}
public Cluster[] createClusters(BufferedImage image, int k) {
// Here the clusters are taken with specific steps,
// so the result looks always same with same image.
// You can randomize the cluster centers, if you like.
Cluster[] result = new Cluster[k];
int x = 0; int y = 0;
int dx = image.getWidth()/k;
int dy = image.getHeight()/k;
for (int i=0;i<k;i++) {
result[i] = new Cluster(i,image.getRGB(x, y));
x+=dx; y+=dy;
}
return result;
}
public Cluster findMinimalCluster(int rgb) {
Cluster cluster = null;
int min = Integer.MAX_VALUE;
for (int i=0;i<clusters.length;i++) {
int distance = clusters[i].distance(rgb);
if (distance<min) {
min = distance;
cluster = clusters[i];
}
}
return cluster;
}
public static void saveImage(String filename,
BufferedImage image) {
File file = new File(filename);
try {
ImageIO.write(image, "png", file);
} catch (Exception e) {
System.out.println(e.toString()+" Image '"+filename
+"' saving failed.");
}
}
public static BufferedImage loadImage(String filename) {
BufferedImage result = null;
try {
result = ImageIO.read(new File(filename));
} catch (Exception e) {
System.out.println(e.toString()+" Image '"
+filename+"' not found.");
}
return result;
}
class Cluster {
int id;
int pixelCount;
int red;
int green;
int blue;
int reds;
int greens;
int blues;
public Cluster(int id, int rgb) {
int r = rgb>>16&0x000000FF;
int g = rgb>> 8&0x000000FF;
int b = rgb>> 0&0x000000FF;
red = r;
green = g;
blue = b;
this.id = id;
addPixel(rgb);
}
public void clear() {
red = 0;
green = 0;
blue = 0;
reds = 0;
greens = 0;
blues = 0;
pixelCount = 0;
}
int getId() {
return id;
}
int getRGB() {
int r = reds / pixelCount;
int g = greens / pixelCount;
int b = blues / pixelCount;
return 0xff000000|r<<16|g<<8|b;
}
void addPixel(int color) {
int r = color>>16&0x000000FF;
int g = color>> 8&0x000000FF;
int b = color>> 0&0x000000FF;
reds+=r;
greens+=g;
blues+=b;
pixelCount++;
red = reds/pixelCount;
green = greens/pixelCount;
blue = blues/pixelCount;
}
void removePixel(int color) {
int r = color>>16&0x000000FF;
int g = color>> 8&0x000000FF;
int b = color>> 0&0x000000FF;
reds-=r;
greens-=g;
blues-=b;
pixelCount--;
red = reds/pixelCount;
green = greens/pixelCount;
blue = blues/pixelCount;
}
int distance(int color) {
int r = color>>16&0x000000FF;
int g = color>> 8&0x000000FF;
int b = color>> 0&0x000000FF;
int rx = Math.abs(red-r);
int gx = Math.abs(green-g);
int bx = Math.abs(blue-b);
int d = (rx+gx+bx) / 3;
return d;
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment