Skip to content

Instantly share code, notes, and snippets.

@thomasjungblut
Last active November 6, 2016 10:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save thomasjungblut/5146284 to your computer and use it in GitHub Desktop.
Save thomasjungblut/5146284 to your computer and use it in GitHub Desktop.
Image Segmentation using mean shift clustering
package de.jungblut.ml;
import java.awt.Color;
import java.awt.FlowLayout;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;
import de.jungblut.clustering.MeanShiftClustering;
import de.jungblut.distance.EuclidianDistance;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.reader.ImageReader;
/**
 * Segments an image with mean shift clustering and shows the segmentation next to the
 * original image in a Swing window.
 *
 * <p>The pixel vectors returned by {@code ImageReader.readImageAsLUV} are clustered with
 * {@link MeanShiftClustering}. Every pixel is then painted with the false color assigned to
 * its nearest cluster center.
 */
public class PictureSegmentation {

  /** Image loaded when no path is passed on the command line. */
  private static final String IMAGE_PATH = "/tmp/Lenna.png";

  /** Metric used to assign each pixel to its nearest cluster center. */
  private static final EuclidianDistance DISTANCE = new EuclidianDistance();

  /**
   * Loads the image, clusters its pixels, prints the number of clusters found, and opens a
   * window with the segmented rendering followed by the raw rendering.
   *
   * @param args optional; {@code args[0]} overrides the default image path
   * @throws IOException if the file cannot be read or no ImageIO reader can decode it
   * @throws Exception anything propagated from the image reader or the clustering call
   */
  public static void main(String[] args) throws Exception {
    String path = args.length > 0 ? args[0] : IMAGE_PATH;
    BufferedImage readImg = ImageIO.read(new File(path));
    if (readImg == null) {
      // ImageIO.read signals "unsupported format" by returning null rather than throwing.
      throw new IOException("No ImageIO reader can decode " + path);
    }
    DoubleVector[] luv = ImageReader.readImageAsLUV(readImg);
    DoubleVector[] rgb = ImageReader.readImageAsRGB(readImg);
    // NOTE(review): the meaning of 5.0, 10.0, 50 and true is defined by
    // MeanShiftClustering.cluster — confirm there before tuning.
    List<DoubleVector> cluster = MeanShiftClustering.cluster(
        Arrays.asList(luv), 5.0, 10.0, 50, true);
    System.out.println(cluster.size());
    List<Color> colors = pick(cluster.size());

    // Rendering into BufferedImages needs no EDT; only the Swing component work does.
    final BufferedImage segmented = getImage(readImg, cluster, luv, colors);
    final BufferedImage raw = getRawImage(readImg, rgb);
    SwingUtilities.invokeLater(new Runnable() {
      @Override
      public void run() {
        showFrame(segmented, raw);
      }
    });
  }

  /**
   * Builds and shows the result window. Must be called on the Swing event dispatch thread.
   *
   * @param segmented false-color segmentation, shown first
   * @param raw image rebuilt from the RGB vectors, shown second
   */
  private static void showFrame(BufferedImage segmented, BufferedImage raw) {
    JFrame frame = new JFrame("Segmentation");
    JPanel panel = new JPanel();
    frame.setLayout(new FlowLayout());
    frame.setLocation(500, 250);
    frame.add(panel);
    panel.add(new JLabel(new ImageIcon(segmented)));
    panel.add(new JLabel(new ImageIcon(raw)));
    frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
    frame.pack();
    frame.setVisible(true);
  }

  /**
   * Rebuilds an image from per-pixel RGB vectors.
   *
   * @param readImg source image; only its width and height are used
   * @param rgb one vector per pixel in row-major order; components 0, 1, 2 are truncated to
   *     int and used as red, green, blue (java.awt.Color rejects values outside 0..255)
   * @return a new TYPE_INT_RGB image with the same dimensions as {@code readImg}
   */
  private static BufferedImage getRawImage(BufferedImage readImg,
      DoubleVector[] rgb) {
    BufferedImage out = newCanvas(readImg);
    int width = readImg.getWidth();
    for (int i = 0; i < rgb.length; i++) {
      DoubleVector rgbVector = rgb[i];
      Color col = new Color((int) rgbVector.get(0), (int) rgbVector.get(1),
          (int) rgbVector.get(2));
      // Pixel i of a row-major layout sits at (i % width, i / width).
      out.setRGB(i % width, i / width, col.getRGB());
    }
    return out;
  }

  /**
   * Paints every pixel with the color of its nearest cluster center.
   *
   * @param readImg source image; only its width and height are used
   * @param cluster cluster centers, in the same space as {@code luv}
   * @param luv one vector per pixel in row-major order
   * @param colors color for each cluster index; must have at least {@code cluster.size()}
   *     entries
   * @return a new TYPE_INT_RGB image with the same dimensions as {@code readImg}
   * @throws IllegalArgumentException if {@code luv} is non-empty and {@code cluster} is empty
   */
  private static BufferedImage getImage(BufferedImage readImg,
      List<DoubleVector> cluster, DoubleVector[] luv, List<Color> colors) {
    BufferedImage out = newCanvas(readImg);
    int width = readImg.getWidth();
    for (int i = 0; i < luv.length; i++) {
      int nearest = nearest(luv[i], cluster);
      out.setRGB(i % width, i / width, colors.get(nearest).getRGB());
    }
    return out;
  }

  /** Creates a blank TYPE_INT_RGB image with the same dimensions as {@code template}. */
  private static BufferedImage newCanvas(BufferedImage template) {
    return new BufferedImage(template.getWidth(), template.getHeight(),
        BufferedImage.TYPE_INT_RGB);
  }

  /**
   * Returns the index of the cluster center closest to {@code doubleVector} under
   * {@link #DISTANCE}. Ties go to the lowest index.
   *
   * @throws IllegalArgumentException if {@code cluster} is empty
   */
  private static int nearest(DoubleVector doubleVector,
      List<DoubleVector> cluster) {
    if (cluster.isEmpty()) {
      throw new IllegalArgumentException("No cluster centers to assign pixels to");
    }
    // Plain min-scan: avoids allocating a distance vector for every pixel.
    int best = 0;
    double bestDistance = DISTANCE.measureDistance(doubleVector, cluster.get(0));
    for (int i = 1; i < cluster.size(); i++) {
      double distance = DISTANCE.measureDistance(doubleVector, cluster.get(i));
      if (distance < bestDistance) {
        best = i;
        bestDistance = distance;
      }
    }
    return best;
  }

  /**
   * Samples {@code num} evenly spaced colors from the ramp implemented by {@link #get(float)}.
   * The first color is at 0 and, for {@code num > 1}, the last is at 1.
   *
   * @param num number of colors; a value of zero or less yields an empty list
   * @return a mutable list with {@code max(num, 0)} colors
   */
  private static List<Color> pick(int num) {
    List<Color> colors = new ArrayList<>(Math.max(num, 0));
    // A denominator of at least 1 keeps num == 1 valid: that single cluster still gets a color.
    int denominator = Math.max(num - 1, 1);
    for (int i = 0; i < num; i++) {
      // Divide instead of accumulating i * step, so the last sample is exactly 1.0f.
      colors.add(get(i / (float) denominator));
    }
    return colors;
  }

  /**
   * Maps {@code x} onto a piecewise-linear ramp in five equal segments:
   * blue, cyan, green, yellow, red, magenta.
   *
   * @param x position on the ramp; values outside [0, 1] are clamped
   * @return the interpolated color
   */
  private static Color get(float x) {
    // Clamp so float rounding just outside [0, 1] cannot fall through to a wrong color.
    x = Math.max(0.0f, Math.min(1.0f, x));
    float r;
    float g;
    float b;
    if (x < 0.2f) {
      x = x / 0.2f;
      r = 0.0f;
      g = x;
      b = 1.0f;
    } else if (x < 0.4f) {
      x = (x - 0.2f) / 0.2f;
      r = 0.0f;
      g = 1.0f;
      b = 1.0f - x;
    } else if (x < 0.6f) {
      x = (x - 0.4f) / 0.2f;
      r = x;
      g = 1.0f;
      b = 0.0f;
    } else if (x < 0.8f) {
      x = (x - 0.6f) / 0.2f;
      r = 1.0f;
      g = 1.0f - x;
      b = 0.0f;
    } else {
      x = (x - 0.8f) / 0.2f;
      r = 1.0f;
      g = 0.0f;
      b = x;
    }
    return new Color(r, g, b);
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment