Skip to content

Instantly share code, notes, and snippets.

@kishida
Last active June 10, 2024 19:58
Show Gist options
  • Save kishida/6c66b3c212f432a19aa176859163e93c to your computer and use it in GitHub Desktop.
Save kishida/6c66b3c212f432a19aa176859163e93c to your computer and use it in GitHub Desktop.
CLIPを使った画像検索
from fastapi import FastAPI
from pydantic import BaseModel
import io
import requests
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face model id shared by the tokenizer, image processor and model.
HF_MODEL_PATH = 'line-corporation/clip-japanese-base'

# All three loaders are called with trust_remote_code=True.
tokenizer = AutoTokenizer.from_pretrained(
    HF_MODEL_PATH,
    trust_remote_code=True,
)
processor = AutoImageProcessor.from_pretrained(
    HF_MODEL_PATH,
    trust_remote_code=True,
)
model = AutoModel.from_pretrained(
    HF_MODEL_PATH,
    trust_remote_code=True,
).to(device)

# ASGI application object that the route decorators below attach to.
app = FastAPI()
# Request body shared by both endpoints: a JSON object with one string
# field named "text".
class TextRequest(BaseModel):
    # For /text_embed this is the text to embed; for /image_embed it is
    # passed to Image.open as the image location.
    text: str
# POST /text_embed: embed the request's text with the CLIP model and return
# the first row of the resulting feature matrix as a plain list of floats.
@app.post("/text_embed")
def embed_text(request: TextRequest):
    # Tokenize the text and move the encoded inputs to the model's device.
    tokens = tokenizer(request.text).to(device)
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        features = model.get_text_features(**tokens)
    # Bring the result back to the CPU and convert it to nested Python lists.
    rows = features.cpu().numpy().tolist()
    first_row = rows[0]
    return {"embedding": first_row}
@app.post("/image_embed")
def embed_image(request: TextRequest):
    """Return the CLIP embedding of the image located at ``request.text``.

    The request reuses ``TextRequest``; its ``text`` field is handed to
    ``Image.open`` as-is, so it names a file readable by this server process.

    Returns a dict ``{"embedding": [...]}`` holding the first row of the
    image feature matrix as a list of floats.

    NOTE(review): the path comes straight from the client, so any file the
    server can read may be opened. Restrict it to a known directory before
    exposing this endpoint beyond a trusted machine.
    """
    # Image.open keeps the underlying file open until the image is closed;
    # the context manager releases the handle as soon as the processor has
    # turned the pixels into tensors, instead of leaking one per request.
    with Image.open(request.text) as image:
        image_t = processor(image, return_tensors="pt").to(device)
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        image_features = model.get_image_features(**image_t)
    embedding = image_features.cpu().numpy().tolist()
    return {"embedding": embedding[0]}
import uvicorn

# Start the HTTP server only when this file is executed as a script.
# Without the guard, merely importing the module (for example by an external
# ASGI runner or a test) would block inside uvicorn.run.
if __name__ == "__main__":
    # Listens on every network interface, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
package imagesearchwithclip;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.Arrays;
/**
 * Minimal HTTP client for the embedding server reachable at {@code BASE_URL}.
 *
 * <p>Both endpoints accept a JSON body of the form {@code {"text": ...}} and
 * answer with {@code {"embedding": [...]}}.
 */
public class ClipClient {
    private static final String BASE_URL = "http://localhost:8000/";

    /** JSON request body: a single {@code text} field. */
    private record TextRequest(String text) {}

    /** JSON response body: a single {@code embedding} array. */
    private record EmbedResponse(double[] embedding) {}

    // Both objects are shared by every call, so they are created once.
    private static final ObjectMapper mapper = new ObjectMapper();
    private static final HttpClient client = HttpClient.newBuilder()
            .version(HttpClient.Version.HTTP_1_1) // FastAPI doesn't work with h2 upgrading
            .build();

    /**
     * Requests the embedding of a piece of text from the {@code text_embed} endpoint.
     *
     * @param text the text to embed
     * @return the embedding vector returned by the server
     */
    static double[] textEmbedding(String text) {
        return embedRequest("text_embed", text);
    }

    /**
     * Requests the embedding of an image from the {@code image_embed} endpoint.
     *
     * @param text the image location, sent to the server in the {@code text} field
     * @return the embedding vector returned by the server
     */
    static double[] imageEmbedding(String text) {
        return embedRequest("image_embed", text);
    }

    /**
     * Posts {@code text} as JSON to {@code BASE_URL + endPoint} and decodes the reply.
     *
     * @param endPoint endpoint name appended to the base URL
     * @param text value of the request's {@code text} field
     * @return the {@code embedding} array of the response
     * @throws UncheckedIOException if sending the request or parsing the reply fails
     * @throws RuntimeException if the server answers with a status other than 200,
     *         or if the calling thread is interrupted while waiting
     */
    private static double[] embedRequest(String endPoint, String text) {
        try {
            String json = mapper.writeValueAsString(new TextRequest(text));
            var req = HttpRequest.newBuilder()
                    .uri(URI.create(BASE_URL + endPoint))
                    .header("Content-Type", "application/json")
                    .POST(HttpRequest.BodyPublishers.ofString(json))
                    .build();
            var res = client.send(req, HttpResponse.BodyHandlers.ofString());
            if (res.statusCode() != 200) {
                // Include status and endpoint so a failure can be diagnosed from the message alone.
                throw new RuntimeException(
                        "HTTP " + res.statusCode() + " from " + endPoint + ": " + res.body());
            }
            var body = mapper.readValue(res.body(), EmbedResponse.class);
            return body.embedding();
        } catch (IOException ex) {
            throw new UncheckedIOException(ex);
        } catch (InterruptedException ex) {
            // Restore the interrupt flag so callers further up can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(ex);
        }
    }
}
package imagesearchwithclip;
import com.fasterxml.jackson.databind.ObjectMapper;
import imagesearchwithclip.Sercher.ImageData;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
/**
 * Command-line tool that embeds every regular file of one directory through
 * {@code ClipClient.imageEmbedding} and writes the results to {@code index.json}.
 */
public class CreateIndex {
    /** Directory scanned when no directory is given on the command line. */
    static final String PATH = "C:\\Users\\naoki\\Desktop\\sample_image";

    /**
     * Builds the index.
     *
     * @param args optional; {@code args[0]} overrides the default image directory
     * @throws IOException if the directory cannot be listed or the index cannot be written
     */
    public static void main(String[] args) throws IOException {
        Path dir = Path.of(args.length > 0 ? args[0] : PATH);
        ImageData[] images = embedAll(dir);
        ObjectMapper mapper = new ObjectMapper();
        try (var out = Files.newOutputStream(Path.of("index.json"))) {
            mapper.writeValue(out, images);
        }
    }

    /** Embeds every non-directory entry of {@code dir}, dropping the ones that fail. */
    private static ImageData[] embedAll(Path dir) throws IOException {
        // Files.list holds an open directory handle until the stream is closed.
        try (var entries = Files.list(dir)) {
            return entries
                    .filter(p -> !Files.isDirectory(p))
                    .map(CreateIndex::embed)
                    // A failed file must not end up as a null entry in the index:
                    // readers of the index dereference every element.
                    .filter(data -> data != null)
                    .toArray(ImageData[]::new);
        }
    }

    /** Embeds one file; logs the failure and returns {@code null} if the request fails. */
    private static ImageData embed(Path p) {
        System.out.println(p);
        try {
            return new ImageData(p, ClipClient.imageEmbedding(p.toString()));
        } catch (RuntimeException e) {
            // Best effort: one unreadable image should not abort the whole run.
            System.out.println(e);
            return null;
        }
    }
}
package imagesearchwithclip;
import com.drew.imaging.ImageMetadataReader;
import com.drew.metadata.exif.ExifIFD0Directory;
import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Font;
import java.awt.Image;
import java.awt.RenderingHints;
import java.awt.geom.AffineTransform;
import java.awt.image.AffineTransformOp;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import javax.imageio.ImageIO;
import javax.swing.ImageIcon;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JTextField;
/**
 * Swing front end for the image search.
 *
 * <p>Pressing Enter in the text field runs the query through {@code Sercher.search}
 * and paints every hit's score and thumbnail onto one canvas image shown below the field.
 */
public class DrawUI {
    public static void main(String[] args) {
        var file = "photos.json";
        Sercher sh = new Sercher(file, 4);
        var f = new JFrame("画像検索 with CLIP");
        var input = new JTextField();
        f.add(BorderLayout.NORTH, input);
        var area = new JLabel();
        f.add(area);
        f.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        final int WIDTH = 750;
        final int HEIGHT = 550;
        f.setSize(WIDTH, HEIGHT);
        f.setVisible(true);
        input.addActionListener(ae -> {
            var text = input.getText();
            long start = System.currentTimeMillis();
            var results = sh.search(text);
            System.out.println("search fin in %dms".formatted(System.currentTimeMillis() - start));
            var image = new BufferedImage(WIDTH, HEIGHT, BufferedImage.TYPE_INT_RGB);
            var g = image.createGraphics();
            try {
                g.setRenderingHint(RenderingHints.KEY_TEXT_ANTIALIASING, RenderingHints.VALUE_TEXT_ANTIALIAS_ON);
                g.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
                // White background, then the headline.
                g.setColor(Color.WHITE);
                g.fillRect(0, 0, WIDTH, HEIGHT);
                g.setColor(Color.BLACK);
                g.setFont(new Font(Font.SANS_SERIF, Font.BOLD, 18));
                g.drawString(text + "の検索結果", 15, 45);
                g.setFont(new Font(Font.SANS_SERIF, Font.PLAIN, 14));
                // Hits are laid out left to right in 250px cells, wrapping to a new row
                // when the next cell would not fit into the canvas width.
                int offset = 20;
                int vpos = 70;
                for (var r : results) {
                    if (r == null) {
                        // An empty result slot has nothing to draw.
                        continue;
                    }
                    System.out.println("draw " + r.image().path());
                    g.drawString("%f".formatted(r.score()), offset, vpos);
                    try {
                        var img = rotatedImage(r.image().path());
                        g.drawImage(img, offset, vpos + 12, null);
                    } catch (IOException ex) {
                        System.out.println("read error:" + r.image().path());
                    }
                    offset += 250;
                    if (offset > WIDTH - 240) {
                        offset = 20;
                        vpos += 250;
                    }
                }
            } finally {
                // Release the native resources held by the graphics context.
                g.dispose();
            }
            area.setIcon(new ImageIcon(image));
            input.setText("");
        });
    }

    /**
     * Loads an image, shrinks it so that neither side exceeds 240 pixels and rotates
     * it according to its EXIF orientation tag (values 3, 6 and 8 are handled).
     *
     * @param path the image file
     * @return the scaled and rotated thumbnail
     * @throws IOException if the file cannot be opened or is not a readable image
     */
    static Image rotatedImage(Path path) throws IOException {
        BufferedImage img;
        // ImageIO.read does not close a stream it is handed, so close it here.
        try (var in = Files.newInputStream(path)) {
            img = ImageIO.read(in);
        }
        if (img == null) {
            throw new IOException("cant read image:" + path);
        }
        int w = img.getWidth();
        int h = img.getHeight();
        int limit = 240;
        double scale = 1;
        int scaledW = w;
        int scaledH = h;
        if (w > limit || h > limit) {
            // Fit the longer side to the limit; clamp the shorter side to at least
            // one pixel because a zero-sized BufferedImage cannot be created.
            if (w > h) {
                scale = limit / (double) w;
                scaledW = limit;
                scaledH = Math.max(1, (int) (h * scale));
            } else {
                scale = limit / (double) h;
                scaledW = Math.max(1, (int) (w * scale));
                scaledH = limit;
            }
        }
        int ori = readOrientation(path);
        var trans = new AffineTransform();
        // Each case rotates about the origin and then shifts the result back into
        // the positive quadrant. The translations use the pre-swap dimensions.
        switch (ori) {
            case 6 -> { // rotated right
                trans.translate(scaledH, 0);
                trans.rotate(Math.toRadians(90));
            }
            case 3 -> { // upside down
                trans.translate(scaledW, scaledH);
                trans.rotate(Math.toRadians(180));
            }
            case 8 -> { // rotated left
                trans.translate(0, scaledW);
                trans.rotate(Math.toRadians(270));
            }
        }
        if (ori == 6 || ori == 8) {
            // A quarter turn swaps the width and height of the output.
            int t = scaledH;
            scaledH = scaledW;
            scaledW = t;
        }
        trans.scale(scale, scale);
        // getType() is TYPE_CUSTOM (0) for some decoded images, and the BufferedImage
        // constructor rejects that value; fall back to ARGB in that case.
        int type = img.getType() == BufferedImage.TYPE_CUSTOM
                ? BufferedImage.TYPE_INT_ARGB
                : img.getType();
        var rotated = new BufferedImage(scaledW, scaledH, type);
        var op = new AffineTransformOp(trans, AffineTransformOp.TYPE_BICUBIC);
        op.filter(img, rotated);
        return rotated;
    }

    /**
     * Reads the EXIF orientation tag of the file.
     *
     * @return the tag value, or 1 when the tag is absent or the metadata cannot be read
     */
    private static int readOrientation(Path path) {
        try (var in = Files.newInputStream(path)) {
            var metadata = ImageMetadataReader.readMetadata(in);
            var dir = metadata.getFirstDirectoryOfType(ExifIFD0Directory.class);
            if (dir != null && dir.containsTag(ExifIFD0Directory.TAG_ORIENTATION)) {
                return dir.getInt(ExifIFD0Directory.TAG_ORIENTATION);
            }
        } catch (Exception ignored) {
            // Best effort: orientation is optional, so an unreadable or malformed
            // metadata block simply means "do not rotate".
        }
        return 1;
    }
}
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build descriptor for the ImageSearchWithClip jar. -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.mycompany</groupId>
<artifactId>ImageSearchWithClip</artifactId>
<version>1.0-SNAPSHOT</version>
<packaging>jar</packaging>
<dependencies>
<!-- Provides com.fasterxml.jackson.databind.ObjectMapper, used for JSON reading and writing. -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.17.0</version>
</dependency>
<!-- Provides the com.drew.* classes used to read the EXIF orientation tag. -->
<dependency>
<groupId>com.drewnoakes</groupId>
<artifactId>metadata-extractor</artifactId>
<version>2.19.0</version>
</dependency>
</dependencies>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<!-- Sources are compiled for, and target, Java 22. -->
<maven.compiler.source>22</maven.compiler.source>
<maven.compiler.target>22</maven.compiler.target>
<!-- NOTE(review): no class named ImageSearchWithClip is visible among the sources shown with this build file; confirm it exists, or point this at the class whose main method should be run. -->
<exec.mainClass>imagesearchwithclip.ImageSearchWithClip</exec.mainClass>
</properties>
</project>
package imagesearchwithclip;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
/**
 * In-memory image search over a JSON index of image embeddings.
 *
 * <p>A query is embedded through {@code ClipClient.textEmbedding} and every indexed
 * image is scored by the dot product of the two embedding vectors.
 */
public class Sercher {
    /** One indexed image: its file path and its embedding vector. */
    record ImageData(Path path, double[] embedding) {}

    /** One search hit: the image and its score against the query. */
    record Result(ImageData image, double score) {}

    /** Number of hits returned when the caller does not specify one. */
    private static final int TOP_COUNT = 3;

    private final ImageData[] images;
    private final int topCount;

    /**
     * Loads the index at {@code path} and returns up to three hits per search.
     *
     * @param path location of the JSON index file
     * @throws UncheckedIOException if the index cannot be read or parsed
     */
    public Sercher(String path) {
        this(path, TOP_COUNT);
    }

    /**
     * Loads the index at {@code path} and returns up to {@code topCount} hits per search.
     *
     * @param path location of the JSON index file
     * @param topCount maximum number of hits returned by {@link #search(String)}
     * @throws IllegalArgumentException if {@code topCount} is less than 1
     * @throws UncheckedIOException if the index cannot be read or parsed
     */
    public Sercher(String path, int topCount) {
        if (topCount < 1) {
            throw new IllegalArgumentException("topCount must be at least 1: " + topCount);
        }
        this.topCount = topCount;
        this.images = loadIndex(path);
    }

    /** Reads the index file as an array; a JSON {@code null} document yields an empty array. */
    private static ImageData[] loadIndex(String path) {
        try (var in = Files.newInputStream(Path.of(path))) {
            ImageData[] loaded = new ObjectMapper().readValue(in, ImageData[].class);
            return loaded != null ? loaded : new ImageData[0];
        } catch (IOException ex) {
            throw new UncheckedIOException(ex);
        }
    }

    /**
     * Dot product of two vectors.
     *
     * @throws IllegalArgumentException if the vectors differ in length
     */
    private static double prod(double[] a, double[] b) {
        if (a.length != b.length) {
            // Different lengths would otherwise either overrun b or silently
            // score on a prefix of it.
            throw new IllegalArgumentException(
                    "embedding length mismatch: " + a.length + " vs " + b.length);
        }
        double score = 0;
        for (int i = 0; i < a.length; ++i) {
            score += a[i] * b[i];
        }
        return score;
    }

    /**
     * Scores every indexed image against {@code text} and returns the best hits.
     *
     * @param text the query text
     * @return the hits in descending score order; at most the configured top count,
     *         fewer when the index holds fewer usable images, never containing {@code null}
     */
    Result[] search(String text) {
        var emb = ClipClient.textEmbedding(text);
        System.out.println(Arrays.toString(emb));
        var scored = new Result[images.length];
        int count = 0;
        for (var img : images) {
            if (img == null || img.embedding() == null) {
                // Skip index entries without an embedding.
                continue;
            }
            var score = prod(emb, img.embedding());
            scored[count++] = new Result(img, score);
            System.out.println("%s: %f".formatted(img.path(), score));
        }
        // The sort is stable, so images with equal scores keep their index order.
        Arrays.sort(scored, 0, count, (x, y) -> Double.compare(y.score(), x.score()));
        return Arrays.copyOf(scored, Math.min(count, topCount));
    }
}
@kishida
Copy link
Author

kishida commented Jun 7, 2024

bandicam.2024-06-08.07-29-26-469.mp4

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment