@sravantit25
Last active July 28, 2023 11:57
Scripts for visual question answering (VQA) on charts using Transformers models
"""
A simple linear script that generates the underlying data table of a given chart image using the DePlot Transformers model.
https://huggingface.co/docs/transformers/main/en/model_doc/deplot
Requirements:
- Python 3.x
- transformers library (pip install transformers)
- PIL (Python Imaging Library) (pip install pillow)
- PyTorch (pip install torch), needed because the inputs are requested with return_tensors="pt"
Inputs:
1. Charts or graphs as image files in .PNG or .JPG format
Output:
Prints the generated underlying data table corresponding to the input figure.
"""
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from PIL import Image
# Define the path of the input image containing the bar chart
IMAGE_PATH = "bar-sample.png"
image = Image.open(IMAGE_PATH)
# Load the model and processor
model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
processor = AutoProcessor.from_pretrained("google/deplot")
# Prepare the inputs for the model, combining the image and text prompt
inputs = processor(images=image, text="Generate underlying data table of the figure below:", return_tensors="pt")
# Generate the underlying data table using the model
predictions = model.generate(**inputs, max_new_tokens=512)
table_data = processor.decode(predictions[0], skip_special_tokens=True)
# Replace the byte tokens <0x0A> (line feed) and <0x0D> (carriage return) so the table prints on a single line
table_data = table_data.replace("<0x0A>", " ").replace("<0x0D>", ",")
# Print the generated underlying data table for the provided chart
print(table_data)
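The decoded DePlot output is one string in which the `<0x0A>` token marks row breaks and `|` separates cells. A minimal sketch of turning that string into a list of rows, assuming this layout holds for the chart at hand. `parse_deplot_table` is a hypothetical helper, not part of transformers, and the sample string is reconstructed from the table used later in this gist rather than copied from real model output.

```python
def parse_deplot_table(raw, row_sep="<0x0A>", cell_sep="|"):
    """Split a linearized DePlot table into a list of rows of stripped cell strings."""
    rows = []
    for line in raw.split(row_sep):
        line = line.strip()
        if not line:
            continue  # skip empty fragments, e.g. after a trailing separator
        rows.append([cell.strip() for cell in line.split(cell_sep)])
    return rows


if __name__ == "__main__":
    # Assumed shape of the raw output, before any <0x0A> replacement
    sample = (
        "TITLE | Investment in AI <0x0A> Years | billion dollars "
        "<0x0A> 2017 | 60 <0x0A> 2018 | 85"
    )
    for row in parse_deplot_table(sample):
        print(row)
```

The helper has to run on the raw decoded string, before the `replace` calls flatten the row breaks into spaces.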
"""
A simple linear script that uses Pix2Struct models to answer questions about a given chart image, reading the questions and expected answers from a CSV file.
This works for google/matcha-chartqa and google/pix2struct-chartqa-base
https://huggingface.co/google/matcha-chartqa
Requirements:
- Python 3.x
- transformers library (pip install transformers)
- PIL (Python Imaging Library) (pip install pillow)
- PyTorch (pip install torch), needed because the inputs are requested with return_tensors="pt"
Inputs:
1. Charts or graphs as image files in .PNG or .JPG format
2. The CSV file should have each row containing a question in the first column and the corresponding expected answer in the second column.
Output:
For each question in the CSV file, the script will print the question, the model-generated answer, and the expected answer.
"""
import csv
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
# Define the path of the input image containing the bar chart
IMAGE_PATH = "bar_chart_simple.png"
image = Image.open(IMAGE_PATH)
# Load the model and processor
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-chartqa-base")
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-chartqa-base")
# Open the CSV file containing questions and expected answers
with open('questions.csv', newline='', encoding='utf-8') as csvfile:
    reader = csv.reader(csvfile)
    for row in reader:
        try:
            question = row[0]
            expected_answer = row[1]
        except IndexError:
            continue
        # Prepare the inputs for the model, combining the image and the question text
        inputs = processor(images=image, text=question, return_tensors="pt")
        # Generate predictions
        predictions = model.generate(**inputs, max_new_tokens=50)
        answer = processor.decode(predictions[0], skip_special_tokens=True)
        # Print the question, model-generated answer, and expected answer for each question in the CSV file
        print(f"Question: {question}\nAnswer: {answer}\nExpected Answer: {expected_answer}\n")
"""
A simple script that uses GPT to answer questions about a table, reading the questions and expected answers from a CSV file.
Uses the gpt-3.5-turbo model. An OpenAI API key is required and is read from the API_KEY environment variable.
Requirements:
- openai library older than 1.0 (pip install "openai<1.0"); openai.ChatCompletion was removed in 1.0
Inputs:
1. Table data extracted from a chart (using the DePlot model in this case). Provide it as SYSTEM_PROMPT.
2. The CSV file should have each row containing a question in the first column and the corresponding expected answer in the second column.
Output:
For each question in the CSV file, the script will print the question, the model-generated answer, and the expected answer.
"""
import os
import csv
import openai
MODEL_ENGINE = "gpt-3.5-turbo"
openai.api_key = os.getenv("API_KEY")
def get_response(input_message, system_prompt, max_tokens=None):
    """
    Return the ChatGPT response.
    Parameters:
    - input_message (str): The user's message to the model.
    - system_prompt (str): The initial prompt provided to the model.
    - max_tokens (int, optional): Maximum number of tokens in the response. Defaults to None.
    Returns:
    - str: The generated response from ChatGPT.
    """
    response = openai.ChatCompletion.create(
        model=MODEL_ENGINE,
        messages=[
            # The system message conventionally comes first, ahead of the user message
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": input_message},
        ],
        max_tokens=max_tokens,
    )
    return response["choices"][0]["message"]["content"]
SYSTEM_PROMPT = "TITLE | Investment in AI Years | billion dollars 2017 | 60 2018 | 85 2019 | 101 2020 | 154 2021 | 276 2022 | 175"
# Read questions and expected answers from a CSV file
with open("questions.csv", newline="", encoding="utf-8") as csvfile:
    reader = csv.reader(csvfile)
    for row in reader:
        try:
            question = row[0]
            expected_answer = row[1]
        except IndexError:
            continue
        # Get the response from ChatGPT for the given question
        gpt_response = get_response(question, SYSTEM_PROMPT, max_tokens=50)
        print("Question:", question)
        print("Answer:", gpt_response)
        print("Actual answer:", expected_answer)
        print()
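Both question-answering scripts print the model's answer next to the expected one and leave the comparison to the reader. A minimal sketch of automating that check, assuming case-insensitive matching for text and numeric equality for numbers. `is_match`, `accuracy`, and the optional `tolerance` parameter are illustrative choices, not something the scripts above define.

```python
def _to_float(text):
    """Return text parsed as a float, or None if it is not numeric."""
    try:
        return float(text.strip().rstrip("%").replace(",", ""))
    except ValueError:
        return None


def is_match(predicted, expected, tolerance=0.0):
    """Compare a model answer with the expected answer.

    Numbers are compared by value with an optional relative tolerance;
    everything else is compared as case-insensitive text.
    """
    pred_num, exp_num = _to_float(predicted), _to_float(expected)
    if pred_num is not None and exp_num is not None:
        if exp_num == 0:
            return pred_num == 0
        return abs(pred_num - exp_num) / abs(exp_num) <= tolerance
    # Text answers: ignore case, surrounding whitespace, and a trailing period
    return predicted.strip().rstrip(".").lower() == expected.strip().rstrip(".").lower()


def accuracy(pairs):
    """Return the fraction of (predicted, expected) pairs that match."""
    pairs = list(pairs)
    if not pairs:
        return 0.0
    return sum(is_match(p, e) for p, e in pairs) / len(pairs)


if __name__ == "__main__":
    results = [("85", "85.0"), ("bar", "Bar"), ("2020", "2021")]
    print(f"Accuracy: {accuracy(results):.2f}")
```

The tolerance defaults to zero on purpose: a relative tolerance of a few percent would also accept 2020 where 2021 is expected, which is wrong for questions that ask for a year.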
What is the chart type?,Bar
What is the title of the chart?,Investment in AI
What is the value for the year 2018?,85
What is the value for the year 2019?,101
What is the value for the year 2017?,60
What is the value for the year 2022?,175
For which year is the investment highest?,2021
Is the value for the year 2020 more than the year 2019?,Yes
What is the color of the bars in the chart?,Green
What is the label or title of the Y-axis?,value
What is the label or title of the X-axis?,Years
Are all the bars in the chart or graph vertical?,Yes
What is the value represented by the smallest bar in the chart?,60
What does the legend represent or display?,billion dollars
How many billion dollars were invested in the year 2021?,276
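Both question-answering scripts read this file with `csv.reader` and skip rows that lack a second column. A minimal sketch that factors that loop into a reusable function. `load_questions` is a hypothetical name, and it accepts any iterable of lines so it can be tried without a file on disk.

```python
import csv
import io


def load_questions(lines):
    """Return (question, expected_answer) pairs, skipping rows with fewer than two columns."""
    pairs = []
    for row in csv.reader(lines):
        if len(row) < 2:
            continue  # blank or malformed row
        pairs.append((row[0].strip(), row[1].strip()))
    return pairs


if __name__ == "__main__":
    # In-memory stand-in for questions.csv
    sample = io.StringIO(
        "What is the chart type?,Bar\n"
        "\n"
        "What is the value for the year 2018?,85\n"
    )
    print(load_questions(sample))
```

With a real file, the same function can be given the handle returned by `open("questions.csv", newline="", encoding="utf-8")`.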