Skip to content

Instantly share code, notes, and snippets.

@kishida
Created March 8, 2023 23:45
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kishida/0ac9f96cbf9f4d4f91906f74205472c8 to your computer and use it in GitHub Desktop.
Save kishida/0ac9f96cbf9f4d4f91906f74205472c8 to your computer and use it in GitHub Desktop.
Find related blog entries using OpenAI embeddings
package naoki.openai;
import com.mongodb.client.MongoClients;
import com.mongodb.client.model.Filters;
import com.theokanning.openai.OpenAiHttpException;
import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.embedding.EmbeddingResult;
import com.theokanning.openai.service.OpenAiService;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.List;
import java.util.regex.Pattern;
public class HatenaReader {
    /** Header fields of the entry currently being parsed; a fresh instance is used per entry. */
    static class Header{ String baseName; String image; String title;
        String date; boolean published;}

    /**
     * One blog entry as stored in the MongoDB {@code entries} collection.
     * {@code body} is the raw body text, {@code stripedBody} the same text with HTML tags
     * removed, and {@code vector} the embedding of (a prefix of) {@code stripedBody}.
     * The component names become the document field names, so they must not change.
     */
    public record BlogEntry (
            String title, String baseName, String image, String date, boolean published,
            String body, String stripedBody, List<Double> vector) { }

    /** Returns the OpenAI API key from the {@code OPENAI_TOKEN} environment variable, or null if unset. */
    static String getToken() {
        return System.getenv("OPENAI_TOKEN");
    }

    /** Hatena Blog export file to import, resolved against the working directory. */
    static final String HATENA_DATA = "nowokay.hatenablog.com.export.txt";

    /** Opening or closing HTML tags whose name starts with a lowercase letter; compiled once. */
    private static final Pattern HTML_TAG = Pattern.compile("<[a-z/][^>]*>");

    /** Maximum number of chars of stripped text sent to the embedding API. */
    private static final int MAX_EMBEDDING_CHARS = 4000;

    /** How many times an embedding request is attempted before the import is aborted. */
    private static final int MAX_ATTEMPTS = 5;

    /**
     * Parses the export file entry by entry and, for every entry whose baseName is not yet in
     * {@code blog_db.entries}, computes an embedding of its tag-stripped text and inserts it.
     * Stops early (returns) when an entry lacks BASENAME, DATE or TITLE, or when the embedding
     * API keeps failing.
     */
    public static void main(String[] args) throws IOException, InterruptedException {
        var path = Path.of(HATENA_DATA);
        try (var bur = Files.newBufferedReader(path);
             var client = MongoClients.create("mongodb://localhost:27017"))
        {
            var db = client.getDatabase("blog_db");
            var coll = db.getCollection("entries", BlogEntry.class);
            // NOTE(review): Duration.ZERO is passed as the client timeout; presumably this
            // disables the HTTP timeout — confirm against the openai-gpt3-java docs.
            var service = new OpenAiService(getToken(), Duration.ZERO);
            // Parser state: HEADER until the "-----" line, CONTENT for the one line that follows
            // it (skipped; presumably the "BODY:" marker), BODY until the next "-----", then
            // COMMENT until the "--------" line that terminates the entry.
            enum Part {HEADER, CONTENT, BODY, COMMENT}
            Part p = Part.HEADER;
            int lineCount = 0;
            int docCount = 0;
            StringBuilder body = new StringBuilder();
            StringBuilder striped = new StringBuilder();
            Header h = new Header();
            for (String line; (line = bur.readLine()) != null; ) {
                switch (p) {
                    case HEADER -> {
                        if (line.startsWith("BASENAME")) {
                            System.out.println("bn:" + (h.baseName = headerValue(line, "BASENAME")));
                        } else if (line.startsWith("IMAGE")) {
                            System.out.println("img:" + (h.image = headerValue(line, "IMAGE")));
                        } else if (line.startsWith("TITLE")) {
                            System.out.println("title:" + (h.title = headerValue(line, "TITLE")));
                        } else if (line.startsWith("DATE")) {
                            System.out.println("date:" + (h.date = headerValue(line, "DATE")));
                        } else if (line.equals("STATUS: Publish")) {
                            h.published = true;
                        } else if (line.equals("-----")) {
                            p = Part.CONTENT;
                        }
                    }
                    case CONTENT -> {
                        // Start a new entry body; the current line itself is not stored.
                        p = Part.BODY;
                        body.setLength(0);
                        striped.setLength(0);
                        lineCount = 0;
                        ++docCount;
                    }
                    case BODY -> {
                        if (line.equals("-----")) {
                            p = Part.COMMENT;
                        } else {
                            var s = HTML_TAG.matcher(line).replaceAll("");
                            striped.append(s).append("\n");
                            body.append(line).append("\n");
                            // Echo only the first three lines of each body as progress output.
                            if (lineCount++ < 3) {
                                System.out.println(s);
                            }
                        }
                    }
                    case COMMENT -> {
                        if (h.baseName == null || h.date == null || h.title == null) {
                            System.out.println("!!");
                            return;
                        }
                        if (line.equals("--------")) {
                            // Entries whose baseName is already stored are skipped, so a re-run
                            // only processes the entries that are still missing.
                            if (coll.find(Filters.eq("baseName", h.baseName)).first() == null) {
                                var text = striped.toString();
                                List<Double> vector = fetchEmbedding(service, text);
                                if (vector == null) {
                                    System.out.println("retry 5 times but could not access");
                                    return;
                                }
                                BlogEntry ent = new BlogEntry(
                                        h.title, h.baseName, h.image, h.date, h.published,
                                        body.toString(), text, vector);
                                coll.insertOne(ent);
                                System.out.println(ent.vector);
                                System.out.println("---");
                                // ~3.1 s between requests keeps us under 20 requests per minute.
                                Thread.sleep(Duration.ofSeconds(3).plusMillis(100));
                            }
                            p = Part.HEADER;
                            h = new Header();
                        }
                    }
                }
            }
            System.out.println(docCount);
        }
    }

    /**
     * Returns what follows {@code "KEY: "} in a header line, or "" when the line is not longer
     * than that prefix (instead of throwing StringIndexOutOfBoundsException).
     */
    private static String headerValue(String line, String key) {
        int start = key.length() + ": ".length();
        return line.substring(Math.min(start, line.length()));
    }

    /**
     * Truncates {@code text} to at most {@code maxChars} UTF-16 units without splitting a
     * surrogate pair, so a character such as an emoji at the cut is dropped rather than halved.
     */
    private static String truncate(String text, int maxChars) {
        if (text.length() <= maxChars) {
            return text;
        }
        int end = maxChars;
        if (end > 0 && Character.isHighSurrogate(text.charAt(end - 1))) {
            end--;
        }
        return text.substring(0, end);
    }

    /**
     * Requests the text-embedding-ada-002 embedding of (a prefix of) {@code text}. Retries on
     * {@link OpenAiHttpException}, waiting one minute between attempts but not after the last.
     * Other exceptions propagate.
     *
     * @return the embedding vector, or null if all {@value #MAX_ATTEMPTS} attempts failed
     * @throws InterruptedException if interrupted while waiting between attempts
     */
    private static List<Double> fetchEmbedding(OpenAiService service, String text)
            throws InterruptedException {
        var req = EmbeddingRequest.builder()
                .user("dummy")
                .model("text-embedding-ada-002")
                .input(List.of(truncate(text, MAX_EMBEDDING_CHARS))).build();
        for (int attempt = 1; attempt <= MAX_ATTEMPTS; ++attempt) {
            try {
                EmbeddingResult res = service.createEmbeddings(req);
                return res.getData().get(0).getEmbedding();
            } catch (OpenAiHttpException ex) {
                System.out.println(ex.getMessage());
                if (attempt < MAX_ATTEMPTS) {
                    Thread.sleep(Duration.ofMinutes(1));
                }
            }
        }
        return null;
    }
}
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>naoki</groupId>
    <artifactId>relatedblog</artifactId>
    <version>1.0</version>
    <packaging>jar</packaging>
    <dependencies>
        <!-- OpenAI client used to request text-embedding-ada-002 embeddings. -->
        <dependency>
            <groupId>com.theokanning.openai-gpt3-java</groupId>
            <artifactId>service</artifactId>
            <version>0.10.0</version>
        </dependency>
        <!-- Synchronous MongoDB driver; entries and keyword vectors live in blog_db. -->
        <dependency>
            <groupId>org.mongodb</groupId>
            <artifactId>mongodb-driver-sync</artifactId>
            <version>4.9.0</version>
        </dependency>
    </dependencies>
    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <!-- Java 19 or later is required: the code calls Thread.sleep(Duration), added in 19. -->
        <maven.compiler.source>19</maven.compiler.source>
        <maven.compiler.target>19</maven.compiler.target>
        <!-- Fully qualified class started by `mvn exec:java`: the web server.
             Run naoki.openai.HatenaReader first to populate blog_db.entries. -->
        <exec.mainClass>naoki.openai.RelatedBlog</exec.mainClass>
    </properties>
</project>
package naoki.openai;
import com.mongodb.client.MongoClients;
import com.mongodb.client.model.Filters;
import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.embedding.EmbeddingResult;
import com.theokanning.openai.service.OpenAiService;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.URLDecoder;
import java.time.Duration;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Predicate;
import java.util.stream.StreamSupport;
public class RelatedBlog {
    /**
     * One blog entry as read from the MongoDB {@code entries} collection;
     * {@code vector} is the embedding of the entry text.
     */
    public record BlogEntry (
            String title, String baseName, String image, String date, boolean published,
            String body, String stripedBody, List<Double> vector) {}

    /** A search keyword and its cached embedding, stored in the {@code keywords} collection. */
    public record Keyword (String word, List<Double> vector) {}

    /** Number of related entries shown for an entry page or a keyword search. */
    private static final int RELATED_COUNT = 3;

    /** Entries whose stripped text is shorter than this are never suggested as related. */
    private static final int MIN_BODY_LENGTH = 50;

    /**
     * Page header template. The first argument is a title suffix (shown in both the title and
     * the h1 via {@code %<s}); the second is the value pre-filled in the search box.
     * Any user-supplied argument must be HTML-escaped by the caller.
     */
    private static final String PAGE_HEADER = """
            <title>近いブログエントリを探す%s</title>
            <h1>近いブログエントリを探す%<s</h1>
            <form method="get" action="/">
            <input type="text" name="q" value="%s" style="width: 300px">
            <input type="submit" value="検索">
            </form>
            """;

    /**
     * Returns at most {@code length} chars of {@code text}, without splitting a surrogate pair
     * (a character such as an emoji at the cut is dropped rather than halved).
     */
    private static String limitText(String text, int length) {
        if (text.length() <= length) {
            return text;
        }
        int end = length;
        if (end > 0 && Character.isHighSurrogate(text.charAt(end - 1))) {
            end--;
        }
        return text.substring(0, end);
    }

    /**
     * Prints one entry as an HTML fragment. The fragment contains a link to the entry on the
     * blog, the first 100 chars of its stripped text, a link to this server's related-entries
     * page for it, and its image if it has one.
     * NOTE(review): title, text and image are inserted without escaping; this assumes the
     * stored export data is already HTML-safe — confirm.
     */
    private static void printEntry(PrintWriter pw, BlogEntry ent) {
        pw.println(
                """
                <div style="margin: 10px 5px">
                <a href="https://nowokay.hatenablog.com/entry/%s">%s</a>
                <div style="margin: 5px 15px">
                <div style="font-size: small; width: 600px">%s</div>
                <a href="/%1$s">近いエントリを探す</a>
                """.formatted(ent.baseName(), ent.title(),
                        limitText(ent.stripedBody(), 100)));
        if (ent.image() != null) {
            pw.println("""
                    <div><img src="%s" style="width: auto; height: 200px"></div>
                    """.formatted(ent.image()));
        }
        pw.println("</div></div>");
    }

    /** Dot product of two embedding vectors over their common length. */
    private static double dot(List<Double> a, List<Double> b) {
        double sum = 0;
        int n = Math.min(a.size(), b.size());
        for (int i = 0; i < n; ++i) {
            sum += a.get(i) * b.get(i);
        }
        return sum;
    }

    /**
     * Prints the {@code limit} published entries whose embeddings have the largest dot product
     * with {@code vector}, best first, each followed by its score. Entries with fewer than
     * {@value #MIN_BODY_LENGTH} chars of stripped text are skipped.
     * Equal scores are kept (not collapsed), and the best match is never dropped.
     *
     * @param excludeBaseName baseName of an entry to leave out (the entry being viewed), or null
     */
    private static void printRelated(PrintWriter pw, List<BlogEntry> entries, List<Double> vector,
                                     String excludeBaseName, int limit) {
        record Score(double score, BlogEntry entry) {}
        entries.stream()
                .filter(e -> e.published() && e.stripedBody().length() >= MIN_BODY_LENGTH)
                .filter(e -> excludeBaseName == null || !excludeBaseName.equals(e.baseName()))
                .map(e -> new Score(dot(vector, e.vector()), e))
                .sorted((a, b) -> Double.compare(b.score(), a.score()))
                .limit(limit)
                .forEach(sc -> {
                    printEntry(pw, sc.entry());
                    pw.println("score: %f".formatted(sc.score()));
                });
    }

    /**
     * Returns the request target of an HTTP request line such as {@code "GET /x?q=y HTTP/1.1"}
     * without its leading '/', or "" when the line is missing or malformed.
     */
    private static String requestTarget(String requestLine) {
        if (requestLine == null) {
            return "";
        }
        String[] parts = requestLine.split(" ");
        if (parts.length < 2) {
            return "";
        }
        String target = parts[1];
        return target.startsWith("/") ? target.substring(1) : target;
    }

    /**
     * Returns the URL-decoded value of query parameter {@code name} in {@code target} (the
     * part after '?'), or "" when there is no query string or the parameter is absent.
     *
     * @throws IllegalArgumentException if the value contains a malformed %-escape
     */
    private static String queryParam(String target, String name) throws IOException {
        int qmark = target.indexOf('?');
        if (qmark < 0) {
            return "";
        }
        for (String pair : target.substring(qmark + 1).split("&")) {
            int eq = pair.indexOf('=');
            String key = eq < 0 ? pair : pair.substring(0, eq);
            if (key.equals(name)) {
                return eq < 0 ? "" : URLDecoder.decode(pair.substring(eq + 1), "utf-8");
            }
        }
        return "";
    }

    /** Escapes the characters that are significant in HTML text and quoted attribute values. */
    private static String escapeHtml(String s) {
        var sb = new StringBuilder(s.length());
        for (int i = 0; i < s.length(); ++i) {
            char c = s.charAt(i);
            switch (c) {
                case '&' -> sb.append("&amp;");
                case '<' -> sb.append("&lt;");
                case '>' -> sb.append("&gt;");
                case '"' -> sb.append("&quot;");
                case '\'' -> sb.append("&#39;");
                default -> sb.append(c);
            }
        }
        return sb.toString();
    }

    /**
     * Serves a minimal HTML UI on port 8989, one connection at a time.
     * <ul>
     *   <li>{@code /?q=word} lists the entries closest to the keyword's embedding. Keyword
     *       vectors are cached in the {@code keywords} collection.</li>
     *   <li>{@code /<baseName>} shows that entry and the entries closest to it.</li>
     *   <li>Anything else shows the search form and the first three entries.</li>
     * </ul>
     * A failure while handling one request is logged and does not stop the server.
     */
    public static void main(String[] args) throws IOException {
        try (var client = MongoClients.create("mongodb://localhost:27017")) {
            var db = client.getDatabase("blog_db");
            var entColl = db.getCollection("entries", BlogEntry.class);
            var keyColl = db.getCollection("keywords", Keyword.class);
            // All entries are loaded once; similarity search is a linear scan over this list.
            List<BlogEntry> entries = StreamSupport.stream(entColl.find().spliterator(), false).toList();
            var service = new OpenAiService(System.getenv("OPENAI_TOKEN"), Duration.ZERO);
            try (ServerSocket serverSoc = new ServerSocket(8989)) {
                for (;;) {
                    // A failing accept() means the listening socket itself is broken: propagate.
                    Socket socket = serverSoc.accept();
                    // Reader and writer use the JVM default charset, which is UTF-8 on JDK 18+
                    // (JEP 400); the Content-Type header below relies on that.
                    try (socket;
                         BufferedReader bur = new BufferedReader(new InputStreamReader(socket.getInputStream()));
                         PrintWriter pw = new PrintWriter(socket.getOutputStream())) {
                        String target = requestTarget(bur.readLine());
                        // Read and discard the request headers, up to the blank line.
                        for (String hdr; (hdr = bur.readLine()) != null && !hdr.isEmpty(); ) {
                        }
                        pw.print("HTTP/1.0 200 OK\r\n");
                        pw.print("Content-Type: text/html; charset=utf-8\r\n");
                        pw.print("\r\n");

                        String q = queryParam(target, "q");
                        if (!q.isBlank()) {
                            // Keyword search: q is echoed into the form, so it must be escaped.
                            pw.println(PAGE_HEADER.formatted("", escapeHtml(q)));
                            var word = keyColl.find(Filters.eq("word", q)).first();
                            List<Double> vec;
                            if (word == null) {
                                var req = EmbeddingRequest.builder()
                                        .user("dummy")
                                        .model("text-embedding-ada-002")
                                        .input(List.of(limitText(q, 4000))).build();
                                EmbeddingResult res = service.createEmbeddings(req);
                                vec = res.getData().get(0).getEmbedding();
                                // Cache the vector so repeating a search does not call the API again.
                                keyColl.insertOne(new Keyword(q, vec));
                            } else {
                                vec = word.vector();
                            }
                            printRelated(pw, entries, vec, null, RELATED_COUNT);
                        } else {
                            // A blank search falls back to the top page instead of embedding "".
                            var entry = target.isEmpty() || target.contains("?") ? null
                                    : entColl.find(Filters.eq("baseName", target)).first();
                            if (entry == null) {
                                pw.println(PAGE_HEADER.formatted("", ""));
                                entries.stream().limit(3).forEach(ent -> printEntry(pw, ent));
                            } else {
                                pw.println(PAGE_HEADER.formatted(" - " + limitText(entry.title(), 15), ""));
                                printEntry(pw, entry);
                                pw.println("<hr><h2>近いエントリ</h2>");
                                printRelated(pw, entries, entry.vector(), entry.baseName(), RELATED_COUNT);
                            }
                        }
                    } catch (IOException | RuntimeException ex) {
                        // Malformed requests, client disconnects and OpenAI/MongoDB errors only
                        // affect the current request.
                        System.out.println("request failed: " + ex);
                    }
                }
            }
        }
    }
}
@kishida
Copy link
Author

kishida commented Mar 8, 2023

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment