Created
April 12, 2023 20:17
-
-
Save kishida/b7c76b650f67eee53e04e36496fa83d3 to your computer and use it in GitHub Desktop.
「おしえてきしださん」bot
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Requires indexing to be run first, using the code here:
// https://gist.github.com/kishida/0ac9f96cbf9f4d4f91906f74205472c8
package naoki.openai; | |
import javax.swing.*; | |
import java.awt.BorderLayout; | |
import java.time.Duration; | |
import java.util.List; | |
import java.util.stream.StreamSupport; | |
import com.mongodb.client.MongoClients; | |
import com.theokanning.openai.completion.chat.ChatCompletionRequest; | |
import com.theokanning.openai.completion.chat.ChatCompletionResult; | |
import com.theokanning.openai.completion.chat.ChatMessage; | |
import com.theokanning.openai.completion.chat.ChatMessageRole; | |
import com.theokanning.openai.embedding.EmbeddingRequest; | |
import com.theokanning.openai.embedding.EmbeddingResult; | |
import com.theokanning.openai.service.OpenAiService; | |
public class HatenaSearch { | |
public record BlogEntry ( | |
String title, String baseName, String image, String date, boolean published, | |
String body, String stripedBody, List<Double> vector) {} | |
static JTextArea outputText; | |
static JTextField searchText; | |
static List<BlogEntry> entries; | |
static OpenAiService service; | |
public static void main(String[] args) { | |
var f = new JFrame("教えてきしださん"); | |
f.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); | |
var p = new JPanel(); | |
searchText = new JTextField(40); | |
var searchButton = new JButton("きしだに聞く"); | |
searchButton.addActionListener(e -> search()); | |
p.add(searchText); | |
p.add(searchButton); | |
f.add(p, BorderLayout.NORTH); | |
outputText = new JTextArea(); | |
outputText.setLineWrap(true); | |
f.add(new JScrollPane(outputText), BorderLayout.CENTER); | |
f.setSize(800, 600); | |
f.setVisible(true); | |
var client = MongoClients.create("mongodb://localhost:27017"); | |
var db = client.getDatabase("blog_db"); | |
var entColl = db.getCollection("entries", BlogEntry.class); | |
entries = StreamSupport.stream(entColl.find().spliterator(), false).toList(); | |
service = new OpenAiService(System.getenv("OPENAI_TOKEN"), Duration.ZERO); | |
} | |
static void search() { | |
var text = searchText.getText(); | |
var req = EmbeddingRequest.builder() | |
.user("dummy") | |
.model("text-embedding-ada-002") | |
.input(List.of(limitText(text, 4000))).build(); | |
EmbeddingResult res = service.createEmbeddings(req); | |
List<Double> vec = res.getData().get(0).getEmbedding(); | |
BlogEntry entry = findRelatedEntries(vec); | |
System.out.println(entry.title()); | |
// ただしユーザーは文章の存在を知らず、見ることもできません。 | |
//「記事によると」など記事を参照するような言葉はつけないでください。 | |
ChatCompletionRequest chatReq = ChatCompletionRequest.builder() | |
.user("dummy") | |
.model("gpt-3.5-turbo") | |
.messages(List.of( | |
new ChatMessage(ChatMessageRole.SYSTEM.value(), | |
""" | |
次の文章を要約して、文章の執筆者の気持ちでユーザーの質問に答えてください。 | |
文章に該当する情報がない場合は、\ | |
「該当する文章がありません。質問に答えることができません」としてください。 | |
文章: | |
%s | |
""".formatted(limitText(entry.body(), 3000))), | |
new ChatMessage(ChatMessageRole.USER.value(), text))) | |
.build(); | |
ChatCompletionResult result = service.createChatCompletion(chatReq); | |
outputText.append(""" | |
質問「%s」 | |
%s | |
via: | |
%s | |
https://nowokay.hatenablog.com/entry/%s | |
""".formatted(text, | |
result.getChoices().get(0).getMessage().getContent() | |
, entry.title(), entry.baseName())); | |
} | |
static String limitText(String text, int lengh) { | |
return text.substring(0, Math.min(text.length(), lengh)); | |
} | |
static BlogEntry findRelatedEntries(List<Double> vector) { | |
double maxScore = 0; | |
BlogEntry entry = null; | |
for (var e : entries) { | |
if (!e.published() || e.stripedBody().length() < 50) { | |
continue; | |
} | |
double score = 0; | |
for (int pos = 0; pos < vector.size(); pos++) { | |
score += vector.get(pos) * e.vector().get(pos); | |
} | |
if (entry == null || maxScore < score) { | |
entry = e; | |
maxScore = score; | |
} | |
} | |
return entry; | |
} | |
} |
Author
kishida
commented
Apr 12, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment