Skip to content

Instantly share code, notes, and snippets.

@rsiemens
Created February 22, 2024 22:03
Show Gist options
  • Save rsiemens/45351f09b5e447fffe022908a5e6c111 to your computer and use it in GitHub Desktop.
Save rsiemens/45351f09b5e447fffe022908a5e6c111 to your computer and use it in GitHub Desktop.
package autohunt;
import java.util.ArrayList;
import java.util.HashMap;
import java.io.File;
import java.io.IOException;
import java.awt.Rectangle;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.awt.Graphics;
/**
 * Locates occurrences of a template ("needle") image inside a larger ("haystack") image.
 *
 * <p>Matching scores each candidate position by the sum of absolute differences (SAD) over the
 * red, green and blue channels, normalised by the largest possible SAD for the needle. A position
 * matches when that normalised error is at most {@code 1 - confidence}.
 *
 * <p>To keep the search cheap, both images are organised as pyramids. Each level is downsampled
 * by {@link #PYRAMID_RATIO} from the previous one. The coarsest level is searched exhaustively;
 * every finer level only re-searches a small window around each hit from the level above.
 */
class NeedleFind {
  /** Scale factor between adjacent pyramid levels. */
  private static final int PYRAMID_RATIO = 2;

  /** A needle image at one pyramid level, with its RGB channels pre-extracted per pixel. */
  class NeedleImage {
    public final BufferedImage img;
    /** Channel values indexed {@code [x][y][c]}, where c is 0 = red, 1 = green, 2 = blue. */
    public final int[][][] rgbCache;
    /** Largest possible SAD against this needle: every channel of every pixel off by 255. */
    public final int maxDiff;
    public final int w;
    public final int h;

    public NeedleImage(BufferedImage img) {
      this.img = img;
      w = img.getWidth();
      h = img.getHeight();
      // multiplyExact: fail loudly rather than silently wrapping for enormous needles.
      maxDiff = Math.multiplyExact(Math.multiplyExact(w, h), 255 * 3);
      rgbCache = new int[w][h][3];
      for (int x = 0; x < w; x++) {
        for (int y = 0; y < h; y++) {
          int rgb = img.getRGB(x, y);
          rgbCache[x][y][0] = rgb >> 16 & 0xFF;
          rgbCache[x][y][1] = rgb >> 8 & 0xFF;
          rgbCache[x][y][2] = rgb & 0xFF;
        }
      }
    }
  }

  /** Needle at each pyramid level; index 0 is full resolution. */
  NeedleImage[] needlePyramid;
  /** Number of pyramid levels (at least 1). */
  int levels;

  /**
   * Loads the needle image and precomputes its pyramid.
   *
   * @param needlePath path of the template image to search for
   * @param levels number of pyramid levels; 1 means search at full resolution only
   * @throws IllegalArgumentException if {@code levels < 1}
   * @throws IOException if the file cannot be read or no ImageIO reader understands it
   */
  public NeedleFind(String needlePath, int levels) throws IOException {
    if (levels < 1) {
      throw new IllegalArgumentException("levels must be >= 1, got " + levels);
    }
    BufferedImage needle = ImageIO.read(new File(needlePath));
    if (needle == null) {
      // ImageIO.read signals "no suitable reader" with null rather than an exception.
      throw new IOException("Unsupported or unreadable image: " + needlePath);
    }
    this.levels = levels;
    needlePyramid = new NeedleImage[levels];
    needlePyramid[0] = new NeedleImage(needle);
    for (int i = 1; i < levels; i++) {
      needlePyramid[i] = new NeedleImage(downSample(needlePyramid[i - 1].img, PYRAMID_RATIO));
    }
  }

  /**
   * Finds non-overlapping occurrences of the needle in {@code haystack}.
   *
   * @param haystack image to search
   * @param confidence required similarity in [0, 1]; a position matches when its normalised SAD
   *     is at most {@code 1 - confidence}
   * @return full-resolution rectangles of the matches; no two returned rectangles overlap
   */
  public ArrayList<Rectangle> find(BufferedImage haystack, double confidence) {
    BufferedImage[] haystackPyramid = new BufferedImage[levels];
    haystackPyramid[0] = haystack;
    for (int i = 1; i < levels; i++) {
      haystackPyramid[i] = downSample(haystackPyramid[i - 1], PYRAMID_RATIO);
    }

    // Start with an exhaustive search at the lowest resolution.
    int top = levels - 1;
    BufferedImage coarsest = haystackPyramid[top];
    ArrayList<Rectangle> found = find(
        coarsest, needlePyramid[top], 0, 0, coarsest.getWidth(), coarsest.getHeight(), confidence);

    // Refine each hit: project it to the next finer level and search a window around it.
    for (int i = levels - 2; i >= 0; i--) {
      NeedleImage needle = needlePyramid[i];
      ArrayList<Rectangle> refined = new ArrayList<>();
      for (Rectangle r : found) {
        // The window spans the projected needle plus a 10% margin on each side, absorbing
        // rounding from the downsample.
        int x0 = (int) Math.floor(r.x * PYRAMID_RATIO - 0.1 * needle.w);
        int y0 = (int) Math.floor(r.y * PYRAMID_RATIO - 0.1 * needle.h);
        int x1 = (int) Math.ceil(r.x * PYRAMID_RATIO + 1.1 * needle.w);
        int y1 = (int) Math.ceil(r.y * PYRAMID_RATIO + 1.1 * needle.h);
        for (Rectangle hit : find(haystackPyramid[i], needle, x0, y0, x1, y1, confidence)) {
          // Neighbouring windows can overlap, so the same match may be found twice.
          if (!overlapsAny(refined, hit)) {
            refined.add(hit);
          }
        }
      }
      found = refined;
    }
    return found;
  }

  /**
   * Exhaustively searches one window of one pyramid level.
   *
   * <p>The window is {@code [xs, xEnd) x [ys, yEnd)} in haystack coordinates and is clipped to
   * the image. Matches are kept in scan order (x-major); a match is dropped when it overlaps one
   * already kept.
   */
  private ArrayList<Rectangle> find(
      BufferedImage haystack,
      NeedleImage needle,
      int xs,
      int ys,
      int xEnd,
      int yEnd,
      double confidence) {
    ArrayList<Rectangle> rects = new ArrayList<>();
    xs = Math.max(xs, 0);
    ys = Math.max(ys, 0);
    xEnd = Math.min(xEnd, haystack.getWidth());
    yEnd = Math.min(yEnd, haystack.getHeight());
    double tolerance = 1 - confidence;
    // "<=": a needle flush with the window's right/bottom edge is a valid placement.
    for (int x = xs; x <= xEnd - needle.w; x++) {
      for (int y = ys; y <= yEnd - needle.h; y++) {
        if (!matches(haystack, needle, x, y, tolerance)) {
          continue;
        }
        Rectangle candidate = new Rectangle(x, y, needle.w, needle.h);
        if (!overlapsAny(rects, candidate)) {
          rects.add(candidate);
        }
      }
    }
    return rects;
  }

  /** Returns whether the needle placed at (x, y) has normalised SAD at most {@code tolerance}. */
  private static boolean matches(
      BufferedImage haystack, NeedleImage needle, int x, int y, double tolerance) {
    double sad = 0.0;
    for (int nx = 0; nx < needle.w; nx++) {
      for (int ny = 0; ny < needle.h; ny++) {
        int rgb = haystack.getRGB(x + nx, y + ny);
        int[] ref = needle.rgbCache[nx][ny];
        sad += Math.abs((rgb >> 16 & 0xFF) - ref[0]);
        sad += Math.abs((rgb >> 8 & 0xFF) - ref[1]);
        sad += Math.abs((rgb & 0xFF) - ref[2]);
      }
      // SAD never decreases, so once the budget is blown no later column can rescue it.
      if (sad / needle.maxDiff > tolerance) {
        return false;
      }
    }
    return sad / needle.maxDiff <= tolerance;
  }

  /** Returns whether {@code candidate} shares at least one pixel with any of {@code rects}. */
  private static boolean overlapsAny(Iterable<Rectangle> rects, Rectangle candidate) {
    for (Rectangle r : rects) {
      if (r.intersects(candidate)) {
        return true;
      }
    }
    return false;
  }

  /**
   * Shrinks {@code src} by an integer factor using {@code Graphics.drawImage} scaling.
   *
   * <p>The result is never smaller than 1x1. Source types that cannot be recreated faithfully
   * fall back to {@code TYPE_INT_ARGB}:
   * <ul>
   *   <li>{@code TYPE_CUSTOM} is rejected by the {@code BufferedImage} constructor.
   *   <li>Indexed and binary types would get a default palette, quantizing colours.
   * </ul>
   */
  private BufferedImage downSample(BufferedImage src, int ratio) {
    int type = src.getType();
    if (type == BufferedImage.TYPE_CUSTOM
        || type == BufferedImage.TYPE_BYTE_INDEXED
        || type == BufferedImage.TYPE_BYTE_BINARY) {
      type = BufferedImage.TYPE_INT_ARGB;
    }
    BufferedImage resize = new BufferedImage(
        Math.max(1, src.getWidth() / ratio),
        Math.max(1, src.getHeight() / ratio),
        type);
    Graphics graphics = resize.createGraphics();
    try {
      graphics.drawImage(src, 0, 0, resize.getWidth(), resize.getHeight(), null);
    } finally {
      graphics.dispose();
    }
    return resize;
  }

  /**
   * Draws a 1-pixel opaque green outline around each rectangle and writes the result as PNG.
   *
   * <p>Note: this modifies {@code haystack} in place. Rectangle parts outside the image are
   * skipped.
   *
   * @param location output path; the data is always PNG-encoded regardless of its extension
   * @throws IOException if writing fails or no PNG writer supports the image type
   */
  public void drawFoundRegions(BufferedImage haystack, ArrayList<Rectangle> rects, String location)
      throws IOException {
    // Full alpha, otherwise the outline is invisible on images with an alpha channel.
    final int green = 0xFF00FF00;
    for (Rectangle r : rects) {
      int right = r.x + r.width - 1;
      int bottom = r.y + r.height - 1;
      for (int y = Math.max(r.y, 0); y <= bottom && y < haystack.getHeight(); y++) {
        for (int x = Math.max(r.x, 0); x <= right && x < haystack.getWidth(); x++) {
          if (y == r.y || y == bottom || x == r.x || x == right) {
            haystack.setRGB(x, y, green);
          }
        }
      }
    }
    // ImageIO.write reports "no suitable writer" via its return value, not an exception.
    if (!ImageIO.write(haystack, "png", new File(location))) {
      throw new IOException("No PNG writer for image type " + haystack.getType() + ": " + location);
    }
  }

  /** Demo: searches {@code image.jpg} for a coal sprite and saves the annotated result. */
  public static void main(String[] args) throws IOException {
    BufferedImage haystack = ImageIO.read(new File("image.jpg"));
    if (haystack == null) {
      throw new IOException("Unsupported or unreadable image: image.jpg");
    }
    NeedleFind needleFinder = new NeedleFind("autohunt/assets/coal_small_2.png", 3);
    long start = System.currentTimeMillis();
    ArrayList<Rectangle> rects = needleFinder.find(haystack, 0.90);
    System.out.printf("find took: %fsec\n", (System.currentTimeMillis() - start) / 1000.0);
    // drawFoundRegions always encodes PNG, so give the output a matching extension.
    needleFinder.drawFoundRegions(haystack, rects, "result.png");
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment