sravantit25/e05e90619de224fb48c73d7b1b390d3e
Scripts for visual question answering (VQA) on charts, using Transformers models (DePlot, Pix2Struct/MatCha) and, for question answering over an extracted data table, OpenAI's GPT.
"""
A simple linear script that generates the underlying data table of a given chart image using the DePlot Transformers model.
https://huggingface.co/docs/transformers/main/en/model_doc/deplot
Requirements:
- Python 3.x
- transformers library (pip install transformers)
- PyTorch (pip install torch), needed for return_tensors="pt"
- PIL (Python Imaging Library) (pip install pillow)
Inputs:
1. A chart or graph image in .PNG or .JPG format
Output:
Prints the generated underlying data table corresponding to the input figure.
"""
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from PIL import Image
# Define the path of the input image containing the bar chart
IMAGE_PATH = "bar-sample.png"
image = Image.open(IMAGE_PATH)
# Load the model and processor
model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
processor = AutoProcessor.from_pretrained("google/deplot")
# Prepare the inputs for the model, combining the image and text prompt
inputs = processor(images=image, text="Generate underlying data table of the figure below:", return_tensors="pt")
# Generate the underlying data table using the model
predictions = model.generate(**inputs, max_new_tokens=512)
table_data = processor.decode(predictions[0], skip_special_tokens=True)
# Replace the literal <0x0A> (newline) markers left in the decoded text so that each table row prints on its own line
table_data = table_data.replace("<0x0A>", "\n").replace("<0x0D>", ",")
# Print the generated underlying data table for the provided chart
print(table_data)
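DePlot's decoded output is a linearized table: judging from the `SYSTEM_PROMPT` used in the GPT script, cells are separated by `|` and rows by the `<0x0A>` marker that the replacement above targets. The following is a minimal sketch, assuming exactly that format, that turns the raw decoded string into a list of rows. `parse_deplot_table` and the sample string are hypothetical and not part of DePlot or Transformers.

```python
from typing import List


def parse_deplot_table(raw: str) -> List[List[str]]:
    """Split a linearized DePlot table into rows of cells.

    Assumes rows are separated by the literal token "<0x0A>" and cells by "|".
    """
    rows = []
    for line in raw.split("<0x0A>"):
        line = line.strip()
        if not line:
            continue  # ignore empty fragments, e.g. from a trailing separator
        rows.append([cell.strip() for cell in line.split("|")])
    return rows


# Hypothetical raw decoder output, modeled on the table in SYSTEM_PROMPT
sample = "TITLE | Investment in AI <0x0A> Years | billion dollars <0x0A> 2017 | 60 <0x0A> 2018 | 85"
for row in parse_deplot_table(sample):
    print(row)
```

To use it on real output, apply it to the string returned by `processor.decode` before any readability replacements, since those remove the row markers.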
"""
A simple linear script that uses Pix2Struct models to answer questions about a given chart image, reading the questions and expected answers from a CSV file.
This works for google/matcha-chartqa and google/pix2struct-chartqa-base.
https://huggingface.co/google/matcha-chartqa
Requirements:
- Python 3.x
- transformers library (pip install transformers)
- PyTorch (pip install torch), needed for return_tensors="pt"
- PIL (Python Imaging Library) (pip install pillow)
Inputs:
1. A chart or graph image in .PNG or .JPG format
2. A CSV file in which each row contains a question in the first column and the corresponding expected answer in the second column.
Output:
For each question in the CSV file, the script prints the question, the model-generated answer, and the expected answer.
"""
import csv
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
# Define the path of the input image containing the bar chart
IMAGE_PATH = "bar_chart_simple.png"
image = Image.open(IMAGE_PATH)
# Load the model and processor
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-chartqa-base")
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-chartqa-base")
# Open the CSV file containing questions and expected answers
with open('questions.csv', newline='', encoding='utf-8') as csvfile:
    reader = csv.reader(csvfile)
    for row in reader:
        try:
            question = row[0]
            expected_answer = row[1]
        except IndexError:
            continue
        # Prepare the inputs for the model, combining the image and the question text
        inputs = processor(images=image, text=question, return_tensors="pt")
        # Generate predictions
        predictions = model.generate(**inputs, max_new_tokens=50)
        answer = processor.decode(predictions[0], skip_special_tokens=True)
        # Print the question, model-generated answer, and expected answer for each question in the CSV file
        print(f"Question: {question}\nAnswer: {answer}\nExpected Answer: {expected_answer}\n")
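Both QA scripts only print the generated answer next to the expected one. To score a run automatically, a comparison looser than exact string equality helps, because numeric answers can differ slightly. The ChartQA benchmark reports a relaxed accuracy that accepts numeric answers within 5% of the target and requires an exact match otherwise; below is a minimal sketch in that spirit. The name `relaxed_match`, the lowercase normalization, the zero-target handling, and the sample answer pairs are assumptions, not the official evaluation code.

```python
def relaxed_match(prediction: str, target: str, tolerance: float = 0.05) -> bool:
    """Return True if prediction matches target under a relaxed criterion.

    Numeric answers match when the relative error is within `tolerance`;
    all other answers must match exactly after trimming and lowercasing.
    """
    pred = prediction.strip().lower()
    tgt = target.strip().lower()
    try:
        pred_num, tgt_num = float(pred), float(tgt)
    except ValueError:
        # At least one side is not a plain number: fall back to string equality
        return pred == tgt
    if tgt_num == 0:
        return pred_num == 0  # avoid dividing by zero
    return abs(pred_num - tgt_num) / abs(tgt_num) <= tolerance


# Hypothetical (model answer, expected answer) pairs
pairs = [("276", "276"), ("280", "276"), ("Yes", "yes"), ("Blue", "Green")]
correct = sum(relaxed_match(pred, gold) for pred, gold in pairs)
print(f"{correct}/{len(pairs)} correct")
```

Inside either QA loop, the result of `relaxed_match(answer, expected_answer)` (or `gpt_response` in the GPT script) could be tallied to report an overall score.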
"""
A simple script that uses GPT to answer questions about a data table, reading the questions and expected answers from a CSV file.
Uses the gpt-3.5-turbo model. An OpenAI API key is required, read from the API_KEY environment variable.
Requirements:
- Python 3.x
- openai library earlier than 1.0 (pip install "openai<1.0"); the openai.ChatCompletion interface used here was removed in version 1.0
Inputs:
1. Table data extracted from a chart (here, with the DePlot model), provided as SYSTEM_PROMPT.
2. A CSV file in which each row contains a question in the first column and the corresponding expected answer in the second column.
Output:
For each question in the CSV file, the script prints the question, the model-generated answer, and the expected answer.
"""
import os
import csv
import openai
MODEL_ENGINE = "gpt-3.5-turbo"
openai.api_key = os.getenv("API_KEY")
def get_response(input_message, system_prompt, max_tokens=None):
    """
    Return the ChatGPT response.
    Parameters:
    - input_message (str): The user's message to the model.
    - system_prompt (str): The system prompt that gives the model its context (here, the chart's data table).
    - max_tokens (int, optional): Maximum number of tokens in the response. Defaults to None.
    Returns:
    - str: The generated response from ChatGPT.
    """
    response = openai.ChatCompletion.create(
        model=MODEL_ENGINE,
        messages=[
            # The system message conventionally comes first, followed by the user's question
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": input_message},
        ],
        max_tokens=max_tokens,
    )
    return response["choices"][0]["message"]["content"]
SYSTEM_PROMPT = "TITLE | Investment in AI Years | billion dollars 2017 | 60 2018 | 85 2019 | 101 2020 | 154 2021 | 276 2022 | 175"
# Read questions and expected answers from a CSV file
with open("questions.csv", newline="", encoding="utf-8") as csvfile:
    reader = csv.reader(csvfile)
    for row in reader:
        try:
            question = row[0]
            expected_answer = row[1]
        except IndexError:
            continue
        # Get the response from ChatGPT for the given question
        gpt_response = get_response(question, SYSTEM_PROMPT, max_tokens=50)
        print("Question:", question)
        print("Answer:", gpt_response)
        print("Expected answer:", expected_answer)
        print()
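Chat-style APIs conventionally expect the system message first, followed by the user turn. The following is a minimal sketch of a hypothetical `build_messages` helper that places the chart's data table in the system message and the question in the user message. The instruction sentence prepended to the table is an assumption, not a tuned prompt. The returned list has the shape expected by the `messages` argument of the chat completion call in `get_response`.

```python
from typing import Dict, List


def build_messages(question: str, table: str) -> List[Dict[str, str]]:
    """Build a chat message list with the table as system context.

    The instruction text below is an illustrative assumption.
    """
    system_prompt = "Answer the question using only the data table below.\n" + table
    return [
        {"role": "system", "content": system_prompt},  # context goes first
        {"role": "user", "content": question},
    ]


# Hypothetical table with one row per line, modeled on SYSTEM_PROMPT
table = "TITLE | Investment in AI\nYears | billion dollars\n2017 | 60\n2018 | 85"
for message in build_messages("What is the value for the year 2018?", table):
    print(message["role"], "->", message["content"].splitlines()[0])
```

Keeping the table rows on separate lines, rather than joined by spaces, makes the row boundaries explicit to the model.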
Question | Expected answer
---|---
What is the chart type? | Bar
What is the title of the chart? | Investment in AI
What is the value for the year 2018? | 85
What is the value for the year 2019? | 101
What is the value for the year 2017? | 60
What is the value for the year 2022? | 175
For which year is the investment highest? | 2021
Is the value for the year 2020 more than the year 2019? | Yes
What is the color of the bars in the chart? | Green
What is the label or title of the Y-axis? | value
What is the label or title of the X-axis? | Years
Are all the bars in the chart or graph vertical? | Yes
What is the value represented by the smallest bar in the chart? | 60
What does the legend represent or display? | billion dollars
How many billion dollars were invested in the year 2021? | 276
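The table above lists the questions and expected answers read by both QA scripts. Because the scripts treat every CSV row as a question, the file presumably has no header row. The following is a minimal sketch of a hypothetical `load_questions` helper that mirrors the scripts' CSV handling. It reads from any file-like object, so it can be exercised on an in-memory sample, and it skips rows with fewer than two fields, as the scripts' `except IndexError: continue` does.

```python
import csv
import io
from typing import Iterator, TextIO, Tuple


def load_questions(csvfile: TextIO) -> Iterator[Tuple[str, str]]:
    """Yield (question, expected_answer) pairs from a two-column CSV."""
    for row in csv.reader(csvfile):
        if len(row) < 2:
            continue  # skip blank or incomplete rows
        yield row[0], row[1]


# In-memory sample using rows from the question table; the last row is deliberately incomplete
sample = io.StringIO(
    "What is the chart type?,Bar\n"
    "What is the value for the year 2018?,85\n"
    "Incomplete row without an answer\n"
)
for question, expected in load_questions(sample):
    print(f"{question} -> {expected}")
```

With the real file, the same helper would be used as `with open("questions.csv", newline="", encoding="utf-8") as f:` followed by a loop over `load_questions(f)`.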