Skip to content

Instantly share code, notes, and snippets.

@kishida
Created April 24, 2023 10:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kishida/ce97494f98871f0f84c8f6c6b4553886 to your computer and use it in GitHub Desktop.
Save kishida/ce97494f98871f0f84c8f6c6b4553886 to your computer and use it in GitHub Desktop.
グラフ説明するくん
package naoki.slm;
import javax.imageio.ImageIO;
import javax.swing.*;
import java.awt.*;
import java.awt.datatransfer.DataFlavor;
import java.awt.datatransfer.UnsupportedFlavorException;
import java.awt.dnd.DnDConstants;
import java.awt.dnd.DropTarget;
import java.awt.dnd.DropTargetDropEvent;
import java.awt.image.BufferedImage;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.ExecutionException;
/**
 * Small Swing tool that answers questions about a chart image.
 *
 * <p>The user drops an image onto the window, types a question and presses the button. The image
 * path and the question are sent, one per line, to a long-running Python helper process
 * ({@code chartdescriber.py}); the helper's one-line reply is appended to the output area.
 */
public class ChartDescribeForm {
    /** Maximum preview width in pixels; larger images are scaled down to fit. */
    private static final int MAX_PREVIEW_WIDTH = 500;
    /** Maximum preview height in pixels; larger images are scaled down to fit. */
    private static final int MAX_PREVIEW_HEIGHT = 400;

    /** Default Python interpreter; overridden by the first command-line argument if given. */
    static final String PYTHON_PATH = "PATH to python";
    /** Default helper script; overridden by the second command-line argument if given. */
    static final String SCRIPT_PATH = "PATH to chartdescriber.py";

    /** Image most recently dropped onto the preview label, or {@code null} if none yet. */
    static File selectedFile;
    /** Request channel to the Python helper's stdin; assigned when the helper is started. */
    static PrintWriter pyIn;
    /** Reply channel from the Python helper's stdout; assigned when the helper is started. */
    static BufferedReader pyOut;

    /**
     * Starts the Python helper, then builds and shows the window.
     *
     * @param args optional {@code [pythonPath [scriptPath]]} overriding {@link #PYTHON_PATH} and
     *     {@link #SCRIPT_PATH}
     * @throws IOException if the Python process cannot be started
     */
    public static void main(String[] args) throws IOException {
        String python = args.length > 0 ? args[0] : PYTHON_PATH;
        String script = args.length > 1 ? args[1] : SCRIPT_PATH;
        startPython(python, script);
        // Swing components must be created and used on the event dispatch thread.
        SwingUtilities.invokeLater(ChartDescribeForm::createAndShowGui);
    }

    /** Launches the helper process and connects {@link #pyIn} / {@link #pyOut} to it. */
    private static void startPython(String python, String script) throws IOException {
        var pb = new ProcessBuilder(python, script);
        // Make Python use UTF-8 on its standard streams, matching the Java side below, instead of
        // each side falling back to its own platform default encoding.
        pb.environment().put("PYTHONIOENCODING", "utf-8");
        // Python block-buffers stdout when it is a pipe; without this a reply can sit in the
        // child's buffer while readLine() waits for it.
        pb.environment().put("PYTHONUNBUFFERED", "1");
        // Nothing here reads the helper's stderr (model-loading logs, warnings, tracebacks). Left
        // as an unread pipe it can fill up and block the child, so forward it to our own stderr.
        pb.redirectError(ProcessBuilder.Redirect.INHERIT);
        Process p = pb.start();
        pyIn = new PrintWriter(new OutputStreamWriter(p.getOutputStream(), StandardCharsets.UTF_8));
        pyOut = new BufferedReader(new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8));
    }

    /** Builds the window: question box and button on top, image drop area, answer log below. */
    private static void createAndShowGui() {
        var frame = new JFrame("グラフの説明");
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        frame.setSize(800, 600);
        frame.setLayout(new GridBagLayout());

        Font font = new Font(Font.SANS_SERIF, Font.PLAIN, 20);

        var input = new JTextArea();
        input.setFont(font);
        input.setLineWrap(true);
        frame.add(new JScrollPane(input), new ConstraintsBuilder()
                .gridx(0).gridy(0).fill(GridBagConstraints.BOTH)
                .weightx(0.7).weighty(0.2).build());

        var button = new JButton("グラフを説明");
        frame.add(button, new ConstraintsBuilder()
                .gridx(1).gridy(0).fill(GridBagConstraints.HORIZONTAL)
                .weightx(0.3).weighty(0.2).build());

        // Label that accepts an image drop and shows a preview of the chart.
        var label = new JLabel("drop image here");
        label.setDropTarget(new ImageDropTarget(label));
        frame.add(label, new ConstraintsBuilder()
                .gridx(0).gridy(1).gridwidth(2).fill(GridBagConstraints.BOTH)
                .weightx(1.0).weighty(0.4).build());

        var output = new JTextArea();
        output.setFont(font);
        output.setLineWrap(true);
        frame.add(new JScrollPane(output), new ConstraintsBuilder()
                .gridx(0).gridy(2).gridwidth(2).fill(GridBagConstraints.BOTH)
                .weightx(1.0).weighty(0.3).build());

        button.addActionListener(e -> {
            var text = input.getText();
            if (text.isEmpty()) {
                return;
            }
            // Capture the selection now, on the EDT, so the background task sees a fixed value.
            var file = selectedFile;
            output.append("Question: " + text + "\n");
            // Model inference can be slow: run it off the EDT so the window stays responsive, and
            // disable the button so requests to the helper cannot overlap.
            button.setEnabled(false);
            new SwingWorker<String, Void>() {
                @Override
                protected String doInBackground() {
                    return describeGraph(file, text);
                }

                @Override
                protected void done() {
                    String result;
                    try {
                        result = get();
                    } catch (InterruptedException ex) {
                        Thread.currentThread().interrupt();
                        result = "error: interrupted";
                    } catch (ExecutionException ex) {
                        ex.printStackTrace();
                        result = "error: " + ex.getCause();
                    }
                    output.append("Answer: " + result + "\n");
                    button.setEnabled(true);
                }
            }.execute();
        });

        frame.setVisible(true);
    }

    /**
     * Asks the Python helper a question about an image and waits for its one-line answer.
     *
     * <p>Synchronized because the exchange is a strict "two lines out, one line back" protocol
     * over shared streams; concurrent callers would otherwise interleave requests and replies.
     *
     * @param path image to ask about; {@code null} yields {@code "no image"}
     * @param question question text; empty yields {@code "no question"}. Line breaks are replaced
     *     by spaces because the helper reads the question as a single line.
     * @return the helper's answer, or a message starting with {@code "error:"} on failure
     */
    static synchronized String describeGraph(File path, String question) {
        if (path == null) {
            return "no image";
        }
        if (question.isEmpty()) {
            return "no question";
        }
        if (pyIn == null || pyOut == null) {
            return "error: the Python process is not running";
        }
        try {
            // Send the image path, then the question, each as exactly one line.
            pyIn.println(path.getAbsolutePath());
            pyIn.println(question.replaceAll("\\R", " "));
            pyIn.flush();
            if (pyIn.checkError()) {
                return "error: could not send the request to the Python process";
            }
            String answer = pyOut.readLine();
            // readLine() returns null at end of stream, i.e. the helper closed its stdout.
            return answer != null ? answer : "error: the Python process has terminated";
        } catch (IOException e) {
            e.printStackTrace();
            return "error: " + e.getMessage();
        }
    }

    /**
     * Returns a size that fits within {@code maxWidth x maxHeight} while keeping the aspect ratio.
     * Sizes that already fit are returned unchanged; scaled sides are truncated but never below 1.
     */
    static Dimension fitWithin(int width, int height, int maxWidth, int maxHeight) {
        if (width <= maxWidth && height <= maxHeight) {
            return new Dimension(width, height);
        }
        double scale = Math.min((double) maxWidth / width, (double) maxHeight / height);
        // getScaledInstance rejects a zero dimension, which very thin images would otherwise hit.
        return new Dimension(Math.max(1, (int) (width * scale)), Math.max(1, (int) (height * scale)));
    }

    /** Drop target that previews a dropped image file and records it as {@link #selectedFile}. */
    static class ImageDropTarget extends DropTarget {
        private final JLabel label;

        public ImageDropTarget(JLabel label) {
            this.label = label;
        }

        @Override
        public void drop(DropTargetDropEvent dtde) {
            if (!dtde.isDataFlavorSupported(DataFlavor.javaFileListFlavor)) {
                dtde.rejectDrop();
                return;
            }
            dtde.acceptDrop(DnDConstants.ACTION_COPY_OR_MOVE);
            boolean success = false;
            try {
                // javaFileListFlavor transfers a java.util.List of File objects.
                @SuppressWarnings("unchecked")
                List<File> files = (List<File>) dtde.getTransferable()
                        .getTransferData(DataFlavor.javaFileListFlavor);
                if (!files.isEmpty()) {
                    success = showImage(files.get(0));
                }
            } catch (UnsupportedFlavorException | IOException e) {
                e.printStackTrace();
            } finally {
                // Always tell the drag source whether the drop was handled.
                dtde.dropComplete(success);
            }
        }

        /**
         * Loads {@code file}, shows a (possibly scaled-down) preview and selects it.
         *
         * @return {@code false} if the file is not a readable image; the selection is then unchanged
         */
        private boolean showImage(File file) throws IOException {
            BufferedImage image = ImageIO.read(file);
            if (image == null) {
                // ImageIO.read returns null when no installed reader understands the file.
                return false;
            }
            selectedFile = file;
            Dimension size = fitWithin(image.getWidth(), image.getHeight(),
                    MAX_PREVIEW_WIDTH, MAX_PREVIEW_HEIGHT);
            Image preview = (size.width == image.getWidth() && size.height == image.getHeight())
                    ? image
                    : image.getScaledInstance(size.width, size.height, Image.SCALE_SMOOTH);
            label.setIcon(new ImageIcon(preview));
            label.setText("");
            return true;
        }
    }

    /** Fluent builder for {@link GridBagConstraints}; each setter returns this builder. */
    static class ConstraintsBuilder {
        private final GridBagConstraints constraints = new GridBagConstraints();

        public ConstraintsBuilder gridx(int gridx) {
            constraints.gridx = gridx;
            return this;
        }

        public ConstraintsBuilder gridy(int gridy) {
            constraints.gridy = gridy;
            return this;
        }

        public ConstraintsBuilder gridwidth(int gridwidth) {
            constraints.gridwidth = gridwidth;
            return this;
        }

        public ConstraintsBuilder gridheight(int gridheight) {
            constraints.gridheight = gridheight;
            return this;
        }

        public ConstraintsBuilder weightx(double weightx) {
            constraints.weightx = weightx;
            return this;
        }

        public ConstraintsBuilder weighty(double weighty) {
            constraints.weighty = weighty;
            return this;
        }

        public ConstraintsBuilder anchor(int anchor) {
            constraints.anchor = anchor;
            return this;
        }

        public ConstraintsBuilder fill(int fill) {
            constraints.fill = fill;
            return this;
        }

        /** Returns the constraints built so far (the same instance on every call). */
        public GridBagConstraints build() {
            return constraints;
        }
    }
}
# chartdescriber.py: line-oriented chart question-answering helper used by ChartDescribeForm.
# Protocol: for each request, read two lines from stdin (image path, then question) and write
# exactly one line to stdout (the answer, or an "error: ..." message).
import sys
import traceback

import torch
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

model_name = "ybelkada/matcha-chartqa"
# Use the first GPU when one is available, otherwise fall back to the (slower) CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
model.to(device)
processor = Pix2StructProcessor.from_pretrained(model_name)


def one_line(text):
    """Return text with its line breaks replaced by spaces so a reply stays a single line."""
    return " ".join(text.splitlines())


while True:
    try:
        path = input()
        q = input()
    except EOFError:
        # The caller closed our stdin (e.g. the Java window was closed): stop serving.
        break
    try:
        # The context manager closes the image file once the processor has read the pixels.
        with Image.open(path) as image:
            inputs = processor(images=image, text=q, return_tensors="pt").to(device)
        predictions = model.generate(**inputs, max_new_tokens=512)
        answer = processor.decode(predictions[0], skip_special_tokens=True)
    except Exception as e:
        # One bad request (missing file, non-image, GPU error, ...) must not kill the helper,
        # or every later question would hit end-of-stream. Log details, reply with an error.
        traceback.print_exc(file=sys.stderr)
        answer = f"error: {e}"
    # stdout is a pipe here, which Python block-buffers; flush so the caller sees the reply now.
    print(one_line(answer), flush=True)
@kishida
Copy link
Author

kishida commented Apr 24, 2023

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment