Created
April 24, 2023 10:54
-
-
Save kishida/ce97494f98871f0f84c8f6c6b4553886 to your computer and use it in GitHub Desktop.
グラフ説明するくん
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package naoki.slm; | |
import java.awt.*;
import java.awt.datatransfer.DataFlavor;
import java.awt.dnd.DnDConstants;
import java.awt.dnd.DropTarget;
import java.awt.dnd.DropTargetDropEvent;
import java.awt.image.BufferedImage;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.List;
import javax.imageio.ImageIO;
import javax.swing.*;
/**
 * Small Swing front end for a chart question-answering script.
 *
 * <p>The form has a question box, a button, a label that accepts a dropped image file and an
 * output area. A Python child process is started once at launch; each request writes two lines to
 * its standard input (absolute image path, then the question) and reads one line (the answer)
 * from its standard output.
 */
public class ChartDescribeForm {
    static final String PYTHON_PATH = "PATH to python";
    static final String SCRIPT_PATH = "PATH to chartdescriber.py";

    /** Largest size, in pixels, at which a dropped image is previewed. */
    private static final int MAX_PREVIEW_WIDTH = 500;
    private static final int MAX_PREVIEW_HEIGHT = 400;

    /** Image file most recently dropped on the form, or {@code null} if none yet. */
    static File selectedFile;
    /** Writer connected to the Python process's standard input. */
    static PrintWriter pyIn;
    /** Reader connected to the Python process's standard output. */
    static BufferedReader pyOut;

    /**
     * Starts the Python process and then shows the form.
     *
     * @param args unused
     * @throws IOException if the Python process cannot be started
     */
    public static void main(String[] args) throws IOException {
        // Start the child first so a bad path fails before any window is created.
        startPython();
        // Swing components must be created and shown on the event dispatch thread.
        SwingUtilities.invokeLater(ChartDescribeForm::createAndShowGui);
    }

    /** Launches the Python script and wires {@link #pyIn} / {@link #pyOut} to its pipes. */
    private static void startPython() throws IOException {
        ProcessBuilder pb = new ProcessBuilder(PYTHON_PATH, SCRIPT_PATH);
        // Nobody reads the child's stderr here; leaving it as an unread pipe can block the
        // child once the pipe buffer fills, so let it share this JVM's stderr instead.
        pb.redirectError(ProcessBuilder.Redirect.INHERIT);
        // Make both ends agree on UTF-8 so non-ASCII questions survive the pipe.
        pb.environment().put("PYTHONIOENCODING", "utf-8");
        Process p = pb.start();
        Runtime.getRuntime().addShutdownHook(new Thread(p::destroy));
        pyIn = new PrintWriter(new OutputStreamWriter(p.getOutputStream(), StandardCharsets.UTF_8));
        pyOut = new BufferedReader(new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8));
    }

    /** Builds the form and makes it visible. Must run on the event dispatch thread. */
    private static void createAndShowGui() {
        var frame = new JFrame("グラフの説明");
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        frame.setSize(800, 600);
        frame.setLayout(new GridBagLayout());

        Font font = new Font(Font.SANS_SERIF, Font.PLAIN, 20);

        // Question input (top left).
        var input = new JTextArea();
        input.setFont(font);
        input.setLineWrap(true);
        frame.add(new JScrollPane(input), new ConstraintsBuilder()
                .gridx(0).gridy(0).fill(GridBagConstraints.BOTH)
                .weightx(0.7).weighty(0.2).build());

        // Submit button (top right).
        var button = new JButton("グラフを説明");
        frame.add(button, new ConstraintsBuilder()
                .gridx(1).gridy(0).fill(GridBagConstraints.HORIZONTAL)
                .weightx(0.3).weighty(0.2).build());

        // Label that shows the chart and accepts dropped image files (middle row).
        var label = new JLabel("drop image here");
        label.setDropTarget(new ImageDropTarget(label));
        frame.add(label, new ConstraintsBuilder()
                .gridx(0).gridy(1).gridwidth(2).fill(GridBagConstraints.BOTH)
                .weightx(1.0).weighty(0.4).build());

        // Question/answer transcript (bottom row).
        var output = new JTextArea();
        output.setFont(font);
        output.setLineWrap(true);
        frame.add(new JScrollPane(output), new ConstraintsBuilder()
                .gridx(0).gridy(2).gridwidth(2).fill(GridBagConstraints.BOTH)
                .weightx(1.0).weighty(0.3).build());

        button.addActionListener(e -> {
            var text = input.getText();
            if (text.isEmpty()) {
                return;
            }
            output.append("Question: " + text + "\n");
            var file = selectedFile;
            // One request at a time; re-enabled when the answer arrives.
            button.setEnabled(false);
            // describeGraph blocks until the script answers, so keep it off the event thread.
            var worker = new Thread(() -> {
                String result;
                try {
                    result = describeGraph(file, text);
                } catch (RuntimeException ex) {
                    result = "error: " + ex;
                }
                var answer = result;
                SwingUtilities.invokeLater(() -> {
                    output.append("Answer: " + answer + "\n");
                    button.setEnabled(true);
                });
            }, "chart-describer");
            worker.setDaemon(true);
            worker.start();
        });

        frame.setVisible(true);
    }

    /**
     * Sends one request to the Python process and returns its one-line reply.
     *
     * <p>Line breaks inside {@code question} are replaced by spaces because the request format is
     * exactly two lines. The method is synchronized so concurrent callers cannot interleave
     * their request and reply lines.
     *
     * @param path image file to ask about; may be {@code null}
     * @param question question text; may be {@code null}
     * @return the reply line, {@code "no image"} if {@code path} is {@code null},
     *     {@code "no question"} if the question is empty after trimming, or a message starting
     *     with {@code "error:"} if the process is not running or the exchange failed
     */
    static synchronized String describeGraph(File path, String question) {
        if (path == null) {
            return "no image";
        }
        String singleLine = question == null ? "" : question.replaceAll("\\R+", " ").trim();
        if (singleLine.isEmpty()) {
            return "no question";
        }
        if (pyIn == null || pyOut == null) {
            return "error: python process is not running";
        }
        try {
            pyIn.println(path.getAbsolutePath());
            pyIn.println(singleLine);
            pyIn.flush();
            // PrintWriter never throws on write failures; it only records them.
            if (pyIn.checkError()) {
                return "error: could not send the request to the python process";
            }
            String answer = pyOut.readLine();
            if (answer == null) {
                // End of stream: the child closed its stdout.
                return "error: python process closed its output";
            }
            return answer;
        } catch (IOException e) {
            e.printStackTrace();
            return "error: " + e.getMessage();
        }
    }

    /** Drop target that loads the first dropped file as an image and previews it on a label. */
    static class ImageDropTarget extends DropTarget {
        final JLabel label;

        /**
         * @param label label whose icon is replaced by the dropped image
         */
        public ImageDropTarget(JLabel label) {
            this.label = label;
        }

        /**
         * Accepts a dropped file list, shows its first file if it decodes as an image and
         * remembers that file in {@link ChartDescribeForm#selectedFile}.
         */
        @Override
        public void drop(DropTargetDropEvent dtde) {
            if (!dtde.isDataFlavorSupported(DataFlavor.javaFileListFlavor)) {
                dtde.rejectDrop();
                return;
            }
            dtde.acceptDrop(DnDConstants.ACTION_COPY_OR_MOVE);
            boolean shown = false;
            try {
                @SuppressWarnings("unchecked") // javaFileListFlavor transfers a List of File
                List<File> files = (List<File>) dtde.getTransferable()
                        .getTransferData(DataFlavor.javaFileListFlavor);
                if (!files.isEmpty()) {
                    shown = showImage(files.get(0));
                }
            } catch (Exception e) {
                // Event-handler boundary: report the failure and leave the form usable.
                e.printStackTrace();
            } finally {
                // An accepted drop must always be completed.
                dtde.dropComplete(shown);
            }
        }

        /**
         * Loads {@code file} and, if it is a readable image, selects and previews it.
         *
         * @return {@code true} if the image was shown
         */
        private boolean showImage(File file) throws IOException {
            BufferedImage image = ImageIO.read(file);
            if (image == null) {
                // ImageIO found no reader for this file; keep the previous selection.
                label.setText("not a supported image: " + file.getName());
                return false;
            }
            selectedFile = file;
            label.setIcon(new ImageIcon(fitPreview(image)));
            label.setText("");
            return true;
        }

        /** Returns {@code source} shrunk, keeping aspect ratio, to fit the preview bounds. */
        private static Image fitPreview(BufferedImage source) {
            int width = source.getWidth();
            int height = source.getHeight();
            if (width <= MAX_PREVIEW_WIDTH && height <= MAX_PREVIEW_HEIGHT) {
                return source;
            }
            double scale = Math.min(
                    (double) MAX_PREVIEW_WIDTH / width, (double) MAX_PREVIEW_HEIGHT / height);
            // Clamp to 1 so an extreme aspect ratio never truncates a side to 0.
            int scaledWidth = Math.max(1, (int) (width * scale));
            int scaledHeight = Math.max(1, (int) (height * scale));
            return source.getScaledInstance(scaledWidth, scaledHeight, Image.SCALE_SMOOTH);
        }
    }

    /** Fluent builder for {@link GridBagConstraints}. */
    static class ConstraintsBuilder {
        private final GridBagConstraints constraints = new GridBagConstraints();

        public ConstraintsBuilder gridx(int gridx) {
            constraints.gridx = gridx;
            return this;
        }

        public ConstraintsBuilder gridy(int gridy) {
            constraints.gridy = gridy;
            return this;
        }

        public ConstraintsBuilder gridwidth(int gridwidth) {
            constraints.gridwidth = gridwidth;
            return this;
        }

        public ConstraintsBuilder gridheight(int gridheight) {
            constraints.gridheight = gridheight;
            return this;
        }

        public ConstraintsBuilder weightx(double weightx) {
            constraints.weightx = weightx;
            return this;
        }

        public ConstraintsBuilder weighty(double weighty) {
            constraints.weighty = weighty;
            return this;
        }

        public ConstraintsBuilder anchor(int anchor) {
            constraints.anchor = anchor;
            return this;
        }

        public ConstraintsBuilder fill(int fill) {
            constraints.fill = fill;
            return this;
        }

        /** Returns a copy of the configured constraints, so the builder can be reused safely. */
        public GridBagConstraints build() {
            return (GridBagConstraints) constraints.clone();
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor | |
from PIL import Image | |
# Checkpoint used for chart question answering, and the device it runs on.
model_name = "ybelkada/matcha-chartqa"
device = "cuda:0"

model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
model.to(device)
processor = Pix2StructProcessor.from_pretrained(model_name)

# Request loop. Each request is two lines on stdin (image path, then question);
# each reply is exactly one line on stdout.
while True:
    try:
        path = input()
        q = input()
    except EOFError:
        # stdin was closed (the parent went away): stop quietly.
        break
    try:
        with Image.open(path) as image:
            inputs = processor(images=image, text=q, return_tensors="pt").to(device)
        predictions = model.generate(**inputs, max_new_tokens=512)
        answer = processor.decode(predictions[0], skip_special_tokens=True)
    except Exception as e:
        # Still reply with one line so the caller is never left waiting.
        answer = f"error: {e}"
    # Collapse line breaks to keep the reply on one line, and flush because
    # stdout is block-buffered when it is not a terminal.
    print(" ".join(answer.splitlines()), flush=True)
Author
kishida
commented
Apr 24, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment