Created February 22, 2024 22:03
-
-
Save rsiemens/45351f09b5e447fffe022908a5e6c111 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package autohunt; | |
import java.util.ArrayList; | |
import java.util.HashMap; | |
import java.io.File; | |
import java.io.IOException; | |
import java.awt.Rectangle; | |
import javax.imageio.ImageIO; | |
import java.awt.image.BufferedImage; | |
import java.awt.Graphics; | |
/**
 * Locates occurrences of a small "needle" image inside a larger "haystack" image.
 *
 * <p>Both images are turned into pyramids by repeated halving. Matching first scans the whole
 * haystack at the coarsest level. Each hit found there is then re-searched at the next finer level,
 * in a small window around its up-scaled position, down to full resolution. A candidate position
 * is scored with the sum of absolute differences (SAD) over the red, green and blue channels.
 */
class NeedleFind {
    /** Scale factor between adjacent pyramid levels. */
    private static final int PYRAMID_SCALE = 2;

    /** Fraction of the needle size added on each side of an up-scaled hit when refining it. */
    private static final double REFINE_MARGIN = 0.1;

    /** A needle image at one pyramid level, with its RGB channels pre-extracted. */
    static class NeedleImage {
        public final BufferedImage img;
        /** Channel values indexed as {@code [x][y][c]}, where c is 0 = red, 1 = green, 2 = blue. */
        public final int[][][] rgbCache;
        /** Largest possible SAD against this needle: every channel of every pixel off by 255. */
        public final long maxDiff;
        public final int w;
        public final int h;

        public NeedleImage(BufferedImage img) {
            this.img = img;
            w = img.getWidth();
            h = img.getHeight();
            // Computed in long: w * h * 765 overflows int for needles above ~2.8M pixels.
            maxDiff = (long) w * h * 255 * 3;
            rgbCache = new int[w][h][3];
            for (int x = 0; x < w; x++) {
                for (int y = 0; y < h; y++) {
                    int rgb = img.getRGB(x, y);
                    rgbCache[x][y][0] = rgb >> 16 & 0xFF;
                    rgbCache[x][y][1] = rgb >> 8 & 0xFF;
                    rgbCache[x][y][2] = rgb & 0xFF;
                }
            }
        }
    }

    /** Needle at each pyramid level; index 0 is full resolution, each next level is half size. */
    NeedleImage[] needlePyramid;
    /** Number of pyramid levels used for both the needle and every haystack. */
    int levels;

    /**
     * Loads the needle image and precomputes its pyramid.
     *
     * @param needlePath path of the image to search for
     * @param levels number of pyramid levels; 1 means full resolution only
     * @throws IllegalArgumentException if {@code levels < 1}
     * @throws IOException if the file cannot be read or its format is not supported by ImageIO
     */
    public NeedleFind(String needlePath, int levels) throws IOException {
        if (levels < 1) {
            throw new IllegalArgumentException("levels must be >= 1, got " + levels);
        }
        BufferedImage needle = ImageIO.read(new File(needlePath));
        if (needle == null) {
            // ImageIO.read signals "no registered reader" with null rather than an exception.
            throw new IOException("Unsupported or unreadable image: " + needlePath);
        }
        this.levels = levels;
        needlePyramid = new NeedleImage[levels];
        needlePyramid[0] = new NeedleImage(needle);
        for (int i = 1; i < levels; i++) {
            needlePyramid[i] = new NeedleImage(downSample(needlePyramid[i - 1].img, PYRAMID_SCALE));
        }
    }

    /**
     * Finds non-overlapping occurrences of the needle in {@code haystack}.
     *
     * <p>A position is accepted when {@code SAD / maxDiff <= 1 - confidence}. When several
     * accepted positions overlap, the first one encountered wins. Positions are scanned column by
     * column (x outer, y inner).
     *
     * @param haystack image to search
     * @param confidence required similarity; 1.0 means exact match, 0.0 accepts everything
     * @return rectangles in full-resolution haystack coordinates, each the size of the needle
     */
    public ArrayList<Rectangle> find(BufferedImage haystack, double confidence) {
        BufferedImage[] haystackPyramid = new BufferedImage[levels];
        haystackPyramid[0] = haystack;
        for (int i = 1; i < levels; i++) {
            haystackPyramid[i] = downSample(haystackPyramid[i - 1], PYRAMID_SCALE);
        }

        // Start at the lowest resolution over the whole image, then refine level by level.
        int coarsest = levels - 1;
        BufferedImage coarseHaystack = haystackPyramid[coarsest];
        ArrayList<Rectangle> found = find(coarseHaystack, needlePyramid[coarsest], 0, 0,
                coarseHaystack.getWidth(), coarseHaystack.getHeight(), confidence);

        for (int i = levels - 2; i >= 0; i--) {
            NeedleImage needle = needlePyramid[i];
            ArrayList<Rectangle> refined = new ArrayList<>();
            for (Rectangle r : found) {
                // Search window: the up-scaled hit, widened by REFINE_MARGIN of the needle size.
                int x0 = (int) Math.floor(r.x * PYRAMID_SCALE - REFINE_MARGIN * needle.w);
                int y0 = (int) Math.floor(r.y * PYRAMID_SCALE - REFINE_MARGIN * needle.h);
                int x1 = (int) Math.ceil(r.x * PYRAMID_SCALE + (1 + REFINE_MARGIN) * needle.w);
                int y1 = (int) Math.ceil(r.y * PYRAMID_SCALE + (1 + REFINE_MARGIN) * needle.h);
                for (Rectangle hit : find(haystackPyramid[i], needle, x0, y0, x1, y1, confidence)) {
                    // Windows of neighbouring coarse hits can overlap; keep one hit per location.
                    addIfNoCollision(refined, hit);
                }
            }
            found = refined;
        }
        return found;
    }

    /**
     * Scans the window [xs, w) x [ys, h) of {@code haystack} for the needle.
     *
     * <p>{@code w} and {@code h} are exclusive END coordinates, not sizes. The window is clamped
     * to the image bounds.
     */
    private ArrayList<Rectangle> find(
            BufferedImage haystack,
            NeedleImage needle,
            int xs,
            int ys,
            int w,
            int h,
            double confidence
    ) {
        ArrayList<Rectangle> rects = new ArrayList<>();
        int left = Math.max(xs, 0);
        int top = Math.max(ys, 0);
        int right = Math.min(w, haystack.getWidth());
        int bottom = Math.min(h, haystack.getHeight());
        int regionW = right - left;
        int regionH = bottom - top;
        if (regionW < needle.w || regionH < needle.h) {
            return rects; // window cannot contain the needle
        }

        // Fetch the window once instead of calling getRGB per pixel per candidate.
        int[] region = haystack.getRGB(left, top, regionW, regionH, null, 0, regionW);
        double tolerance = 1 - confidence;

        // "<=": the needle fits at every top-left x with x + needle.w <= right (same for y).
        for (int x = left; x <= right - needle.w; x++) {
            for (int y = top; y <= bottom - needle.h; y++) {
                long sad = sumOfAbsoluteDifferences(region, regionW, x - left, y - top, needle, tolerance);
                if ((double) sad / needle.maxDiff <= tolerance) {
                    addIfNoCollision(rects, new Rectangle(x, y, needle.w, needle.h));
                }
            }
        }
        return rects;
    }

    /**
     * Returns the RGB SAD between the needle and the region pixels at offset (ox, oy).
     *
     * <p>The SAD only grows as pixels are added. Once it already exceeds the tolerance, this returns
     * early with a partial sum that is still above the threshold, so the caller rejects it either way.
     *
     * @param region packed RGB pixels, row-major with row length {@code stride}
     */
    private static long sumOfAbsoluteDifferences(
            int[] region, int stride, int ox, int oy, NeedleImage needle, double tolerance) {
        long sad = 0;
        for (int nx = 0; nx < needle.w; nx++) {
            for (int ny = 0; ny < needle.h; ny++) {
                int rgb = region[(oy + ny) * stride + ox + nx];
                int[] expected = needle.rgbCache[nx][ny];
                sad += Math.abs((rgb >> 16 & 0xFF) - expected[0]);
                sad += Math.abs((rgb >> 8 & 0xFF) - expected[1]);
                sad += Math.abs((rgb & 0xFF) - expected[2]);
            }
            if ((double) sad / needle.maxDiff > tolerance) {
                return sad;
            }
        }
        return sad;
    }

    /** Adds {@code candidate} unless it overlaps a rectangle already in {@code rects}. */
    private static boolean addIfNoCollision(ArrayList<Rectangle> rects, Rectangle candidate) {
        for (Rectangle r : rects) {
            if (r.intersects(candidate)) {
                return false;
            }
        }
        rects.add(candidate);
        return true;
    }

    /**
     * Shrinks {@code src} by {@code ratio} in each dimension, to at least 1x1.
     *
     * <p>The result is always a packed-int RGB or ARGB image. {@code src.getType()} may be
     * TYPE_CUSTOM, which the BufferedImage constructor rejects. It may also be an indexed type
     * whose default palette would quantize colours and distort matching.
     */
    private static BufferedImage downSample(BufferedImage src, int ratio) {
        int width = Math.max(1, src.getWidth() / ratio);
        int height = Math.max(1, src.getHeight() / ratio);
        int type = src.getColorModel().hasAlpha() ? BufferedImage.TYPE_INT_ARGB : BufferedImage.TYPE_INT_RGB;
        BufferedImage resized = new BufferedImage(width, height, type);
        Graphics graphics = resized.createGraphics();
        try {
            graphics.drawImage(src, 0, 0, width, height, null);
        } finally {
            graphics.dispose();
        }
        return resized;
    }

    /**
     * Draws a 1-pixel opaque green outline around each rectangle, in place on {@code haystack}.
     * Then writes the image as PNG to {@code location}.
     *
     * @throws IOException if writing fails or no PNG writer can encode the image
     */
    public void drawFoundRegions(BufferedImage haystack, ArrayList<Rectangle> rects, String location) throws IOException {
        // Alpha byte must be 0xFF, otherwise the outline is fully transparent on ARGB images.
        final int outline = 0xFF00FF00;
        int imageW = haystack.getWidth();
        int imageH = haystack.getHeight();
        for (Rectangle r : rects) {
            int left = r.x;
            int top = r.y;
            int right = r.x + r.width - 1;
            int bottom = r.y + r.height - 1;
            for (int y = Math.max(top, 0); y <= bottom && y < imageH; y++) {
                for (int x = Math.max(left, 0); x <= right && x < imageW; x++) {
                    if (y == top || y == bottom || x == left || x == right) {
                        haystack.setRGB(x, y, outline);
                    }
                }
            }
        }
        if (!ImageIO.write(haystack, "png", new File(location))) {
            throw new IOException("No PNG writer could encode image for " + location);
        }
    }

    /**
     * Demo entry point.
     *
     * <p>Optional args: haystack path, needle path, output path. Defaults are "image.jpg",
     * "autohunt/assets/coal_small_2.png" and "result.png".
     */
    public static void main(String[] args) throws IOException {
        String haystackPath = args.length > 0 ? args[0] : "image.jpg";
        String needlePath = args.length > 1 ? args[1] : "autohunt/assets/coal_small_2.png";
        // PNG extension: drawFoundRegions always encodes PNG.
        String outputPath = args.length > 2 ? args[2] : "result.png";

        BufferedImage haystack = ImageIO.read(new File(haystackPath));
        if (haystack == null) {
            throw new IOException("Unsupported or unreadable image: " + haystackPath);
        }
        NeedleFind needleFinder = new NeedleFind(needlePath, 3);
        long start = System.nanoTime();
        ArrayList<Rectangle> rects = needleFinder.find(haystack, 0.90);
        System.out.printf("find took: %fsec\n", (System.nanoTime() - start) / 1e9);
        needleFinder.drawFoundRegions(haystack, rects, outputPath);
    }
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment