Created
March 8, 2023 23:45
-
-
Save kishida/0ac9f96cbf9f4d4f91906f74205472c8 to your computer and use it in GitHub Desktop.
Get related entries using OpenAI embeddings
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package naoki.openai; | |
import com.mongodb.client.MongoClients;
import com.mongodb.client.model.Filters;
import com.theokanning.openai.OpenAiHttpException;
import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.embedding.EmbeddingResult;
import com.theokanning.openai.service.OpenAiService;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.List;
import java.util.regex.Pattern;
public class HatenaReader { | |
static class Header{ String baseName; String image; String title; | |
String date; boolean published;} | |
public record BlogEntry ( | |
String title, String baseName, String image, String date, boolean published, | |
String body, String stripedBody, List<Double> vector) { } | |
static String getToken() { | |
return System.getenv("OPENAI_TOKEN"); | |
} | |
static final String HATENA_DATA = "nowokay.hatenablog.com.export.txt"; | |
public static void main(String[] args) throws IOException, InterruptedException { | |
var path = Path.of(HATENA_DATA); | |
try (var bur = Files.newBufferedReader(path); | |
var client = MongoClients.create("mongodb://localhost:27017")) | |
{ | |
var db = client.getDatabase("blog_db"); | |
var coll = db.getCollection("entries", BlogEntry.class); | |
var service = new OpenAiService(getToken(), Duration.ZERO); | |
enum Part {HEADER, CONTENT, BODY, COMMENT} | |
Part p = Part.HEADER; | |
int lineCount = 0; | |
int docCount = 0; | |
StringBuilder body = new StringBuilder(); | |
StringBuilder striped = new StringBuilder(); | |
Header h = new Header(); | |
for (String line; (line = bur.readLine()) != null; ) { | |
switch (p) { | |
case HEADER -> { | |
if (line.startsWith("BASENAME")) { | |
System.out.println("bn:" + (h.baseName = line.substring("BASENAME: ".length()))); | |
} else if(line.startsWith("IMAGE")) { | |
System.out.println("img:" + (h.image = line.substring("IMAGE: ".length()))); | |
} else if(line.startsWith("TITLE")) { | |
System.out.println("title:" + (h.title = line.substring("TITLE: ".length()))); | |
} else if(line.startsWith("DATE")) { | |
System.out.println("date:" + (h.date = line.substring("DATE: ".length()))); | |
} else if(line.equals("STATUS: Publish")) { | |
h.published = true; | |
} else if (line.equals("-----")) { | |
p = Part.CONTENT; | |
} | |
} | |
case CONTENT -> { | |
p = Part.BODY; | |
body.setLength(0); | |
striped.setLength(0); | |
lineCount = 0; | |
++docCount; | |
} | |
case BODY -> { | |
if (line.equals("-----")) { | |
p = Part.COMMENT; | |
} else { | |
var s = line.replaceAll("<[a-z/][^>]*>", ""); | |
striped.append(s).append("\n"); | |
body.append(line).append("\n"); | |
if (lineCount++ < 3) { | |
System.out.println(s); | |
} | |
} | |
} | |
case COMMENT -> { | |
if (h.baseName == null || h.date == null || h.title == null) { | |
System.out.println("!!"); | |
return; | |
} | |
if (line.equals("--------")) { | |
if (coll.find(Filters.eq("baseName", h.baseName)).first() == null) { | |
var text = striped.toString(); | |
var req = EmbeddingRequest.builder() | |
.user("dummy") | |
.model("text-embedding-ada-002") | |
.input(List.of(text.substring(0, Math.min(text.length(), 4000)))).build(); | |
EmbeddingResult res = null; | |
for (int i = 0; i < 5; ++i) { | |
try { | |
res = service.createEmbeddings(req); | |
} catch (OpenAiHttpException ex) { | |
System.out.println(ex.getMessage()); | |
Thread.sleep(Duration.ofMinutes(1)); | |
continue; | |
} | |
break; | |
} | |
if (res == null) { | |
System.out.println("retry 5 times but could not access"); | |
return; | |
} | |
BlogEntry ent = new BlogEntry( | |
h.title, h.baseName, h.image, h.date, h.published, | |
body.toString(), text, res.getData().get(0).getEmbedding()); | |
coll.insertOne(ent); | |
System.out.println(ent.vector); | |
System.out.println("---"); | |
Thread.sleep(Duration.ofSeconds(3).plusMillis(100)); // 20 request per min for the rate limit | |
} | |
p = Part.HEADER; | |
h = new Header(); | |
} | |
} | |
} | |
} | |
System.out.println(docCount); | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<?xml version="1.0" encoding="UTF-8"?>
<!--
  Build for HatenaReader: reads a Hatena Blog export, embeds each entry with the
  OpenAI embeddings API and stores the entries in MongoDB.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>naoki</groupId>
    <artifactId>relatedblog</artifactId>
    <version>1.0</version>
    <packaging>jar</packaging>
    <dependencies>
        <!-- OpenAI API client: OpenAiService, EmbeddingRequest, EmbeddingResult. -->
        <dependency>
            <groupId>com.theokanning.openai-gpt3-java</groupId>
            <artifactId>service</artifactId>
            <version>0.10.0</version>
        </dependency>
        <!-- Synchronous MongoDB driver: MongoClients, Filters. -->
        <dependency>
            <groupId>org.mongodb</groupId>
            <artifactId>mongodb-driver-sync</artifactId>
            <version>4.9.0</version>
        </dependency>
    </dependencies>
    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <maven.compiler.source>19</maven.compiler.source>
        <maven.compiler.target>19</maven.compiler.target>
        <!-- Fully qualified name of the class with main(); read by the exec plugin (exec:java). -->
        <exec.mainClass>naoki.openai.HatenaReader</exec.mainClass>
    </properties>
</project>
Author
kishida
commented
Mar 8, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment