Skip to content

Instantly share code, notes, and snippets.

@sravantit25
Last active December 14, 2023 10:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sravantit25/d1879eddbab15a62b12c4223460d6e3f to your computer and use it in GitHub Desktop.
Save sravantit25/d1879eddbab15a62b12c4223460d6e3f to your computer and use it in GitHub Desktop.
What is the chart type? Bar
What is the title of the chart? Country Summary
What is the value for Austria for the year 2004? 57000
What is the value for Brazil for the year 2003? 20000
What is the value for France for the year 2004? 46000
Which color represents the year 2003? Red
Which country has highest value for the year 2004? Italy
For the country USA is the value for the year 2005 more than the year 2004? Yes
What is the label or title of the X axis? Country
Which country has highest value for year 2005? USA
What is the value for France for the year 2005? 19000
question expected_answer
What is the title of the chart? Investment in AI
What is the chart type? Bar chart
What is the investment in AI for the year 2022? 175 billion dollars
What is the investment in the year 2018? 85 billion dollars
What is the value for the year 2017? 60 billion dollars
What colors are the bars in the chart? Green
For which year is the investment highest? 2021
Is the value for the year 2020 more than the year 2019? Yes
What is the label or title on the X-axis? Years
What is the label or title on the Y-axis? value
What is the value represented by the smallest bar in the chart? 60 billion dollars
What does the legend represent or display? billion dollars
What is the percentage increase from year 2020 to year 2021? 79 percentage
Are there 6 bars on the chart? Yes
Are all the bars on the chart vertical? Yes
"""
This script helps to perform Visual Question Answering on charts using GPT-4 Vision model.
It reads questions and expected answers from a CSV file, provides them to the model
and captures the response.
Note: Set the 'API_KEY' environment variable and update the 'IMAGE_PATH' and 'FILE_PATH' variables with your specific values.
"""
import base64
import csv
import mimetypes
import os
import sys

import openai
import requests
# OpenAI API key, read from the 'API_KEY' environment variable (None when unset).
API_KEY = os.getenv("API_KEY")
# Also hand the key to the openai package; the request built in this file
# passes the key explicitly through the Authorization header instead.
openai.api_key = API_KEY
# Chat Completions REST endpoint that the payload is POSTed to.
API_URL = "https://api.openai.com/v1/chat/completions"
# Identifier of the model placed in the request payload.
MODEL = "gpt-4-vision-preview"
class ReadCSVError(Exception):
    """Signal that the questions CSV file is missing, unreadable or malformed."""
class PayloadGenerationError(Exception):
    """Signal that the request payload for the model could not be built."""
class ResponseParsingError(Exception):
    """Signal that the model's response could not be interpreted."""
def encode_image(image_path):
    """
    Encode an image file as a base64 string.

    Params:
    - image_path (str): The path to the image file

    Returns:
    - str: base64-encoded image data

    Raises:
    - PayloadGenerationError: If the image file is missing or cannot be read
    """
    try:
        with open(image_path, "rb") as image_file:
            raw_bytes = image_file.read()
    except FileNotFoundError as file_not_found_error:
        # Build one message string: passing the exception as a second
        # positional argument makes str(error) render as a tuple repr.
        raise PayloadGenerationError(
            f"Cannot find the image file: {file_not_found_error}"
        ) from file_not_found_error
    except OSError as os_error:
        # Covers paths that exist but cannot be read (directory, permissions).
        raise PayloadGenerationError(
            f"Cannot read the image file: {os_error}"
        ) from os_error
    return base64.b64encode(raw_bytes).decode("utf-8")
def prepare_input_questions(questions):
    """
    Wrap every question in the text-content structure used in a user message.

    Params:
    - questions (list): Question strings to send to the model

    Returns:
    - list: One {"type": "text", "text": <question>} dictionary per question,
      in the same order as the input

    Raises:
    - ValueError: If 'questions' is not a list
    """
    if isinstance(questions, list):
        return [{"type": "text", "text": entry} for entry in questions]
    raise ValueError("Invalid input: 'questions' must be a list.")
def generate_payload(api_key, model, image_path, questions, max_tokens=300):
    """
    Generate the headers and payload for the OpenAI API request.

    Params:
    - api_key (str): OpenAI API key
    - model (str): OpenAI model identifier
    - image_path (str): Path to the image file
    - questions (list): List of strings representing questions
    - max_tokens (int): Upper bound on the tokens in the reply (default 300)

    Returns:
    - tuple: A tuple containing the headers and payload for the OpenAI API

    Raises:
    - PayloadGenerationError: If the image cannot be encoded or 'questions'
      is not a list
    """
    base64_image = encode_image(image_path)
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}

    try:
        prepared_questions = prepare_input_questions(questions)
    except ValueError as value_error:
        raise PayloadGenerationError(value_error) from value_error

    # Label the data URL with the image's actual type (e.g. image/png) rather
    # than always claiming JPEG; fall back to JPEG when the extension is
    # unknown or does not map to an image type.
    mime_type, _ = mimetypes.guess_type(str(image_path))
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    image_content = {
        "type": "image_url",
        "image_url": {"url": f"data:{mime_type};base64,{base64_image}"},
    }
    payload = {
        "model": model,
        "messages": [
            {
                "role": "user",
                # All questions first, followed by the image they refer to.
                "content": prepared_questions + [image_content],
            }
        ],
        "max_tokens": max_tokens,
    }
    return headers, payload
def get_response(response_json, questions, expected_answers):
    """
    Print each answer line of the model's reply next to its question.

    The reply text is split into lines; every non-blank line is taken as the
    answer to the question at the same position (numbering starts at 0).

    Params:
    - response_json (dict): The JSON response from the OpenAI API
    - questions (list): List of strings representing questions
    - expected_answers (list): List of strings representing expected answers

    Raises:
    - ResponseParsingError: If the response lacks the expected structure,
      carries an 'error' entry, has no text content, or contains more answer
      lines than there are questions / expected answers
    """
    try:
        assistant_message = response_json["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as lookup_error:
        # Prefer the error details in the response, when present, over a
        # bare "'choices'" KeyError message.
        api_error = (
            response_json.get("error") if isinstance(response_json, dict) else None
        )
        if api_error:
            raise ResponseParsingError(
                f"API returned an error: {api_error}"
            ) from lookup_error
        raise ResponseParsingError(lookup_error) from lookup_error

    if not isinstance(assistant_message, str):
        raise ResponseParsingError("The assistant message has no text content")

    # splitlines() also handles '\r\n'; whitespace-only lines are separators
    # between answers, not answers themselves.
    answers = [line for line in assistant_message.splitlines() if line.strip()]
    available = min(len(questions), len(expected_answers))

    for i, answer in enumerate(answers):
        if i >= available:
            raise ResponseParsingError(
                f"The model returned {len(answers)} answer lines but only "
                f"{available} questions with expected answers are available"
            )
        print(f"Question {i}: {questions[i]}")
        print(f"Answer {i}: {answer}")
        print(f"Expected Answer {i}: {expected_answers[i]}")
        print()
def process_row(row):
    """
    Split one CSV row into its question and expected answer.

    Params:
    - row (list): The cells of a single CSV row

    Returns:
    - tuple: (question, expected_answer), taken from the first and second cell

    Raises:
    - ReadCSVError: If the row does not contain exactly two cells
    """
    # Anything other than exactly two cells is treated as a malformed row.
    if len(row) != 2:
        raise ReadCSVError("Invalid CSV file structure")
    question, expected_answer = row
    return question, expected_answer
def read_csv(file_path):
    """
    Read questions and expected answers from a CSV file.

    The first row is treated as a header and skipped. Blank rows are ignored;
    every other row must hold exactly a question and an expected answer.

    Params:
    - file_path (str): Path to the CSV file

    Returns:
    - tuple: A tuple containing two lists - 'questions' and 'expected_answers'
      (both empty when the file has no data rows)

    Raises:
    - ReadCSVError: If the file is missing, cannot be read or decoded, or a
      row does not have the expected structure
    """
    questions = []
    expected_answers = []
    try:
        # newline="" lets the csv module do its own newline handling, as the
        # csv documentation recommends for files passed to csv.reader.
        with open(file_path, "r", encoding="utf-8", newline="") as csv_file:
            reader = csv.reader(csv_file)
            # Skip the header; the default prevents StopIteration on an
            # empty file.
            next(reader, None)
            for row in reader:
                # Blank lines (e.g. a trailing newline at the end of the
                # file) carry no question and are not a structural error.
                if not any(cell.strip() for cell in row):
                    continue
                question, expected_answer = process_row(row)
                questions.append(question)
                expected_answers.append(expected_answer)
    except FileNotFoundError as file_not_found_error:
        raise ReadCSVError(
            f"The specified CSV file is not found: {file_not_found_error}"
        ) from file_not_found_error
    except OSError as os_error:
        raise ReadCSVError(f"Unable to read the CSV file: {os_error}") from os_error
    except (csv.Error, UnicodeDecodeError) as csv_error:
        raise ReadCSVError(f"Invalid CSV file: {csv_error}") from csv_error
    return questions, expected_answers
if __name__ == "__main__":
    # Replace these placeholder strings with real paths before running.
    IMAGE_PATH = "<a jpeg or png image of the chart>"
    FILE_PATH = "<csv file consisting of questions>"

    # Without a key the request can only fail with an authentication error,
    # so stop early with a clear message.
    if not API_KEY:
        print("Error: the 'API_KEY' environment variable is not set")
        sys.exit(1)

    try:
        input_questions, input_expected_answers = read_csv(FILE_PATH)
    except ReadCSVError as csv_error:
        print("Error reading CSV file: ", csv_error)
        sys.exit(1)

    try:
        input_headers, input_payload = generate_payload(
            API_KEY, MODEL, IMAGE_PATH, input_questions
        )
    except PayloadGenerationError as error:
        print(f"Error generating payload: {error}")
        sys.exit(1)

    # Connection failures and timeouts surface as RequestException.
    try:
        response = requests.post(
            API_URL, headers=input_headers, json=input_payload, timeout=30
        )
    except requests.exceptions.RequestException as error:
        print(f"Error calling the API: {error}")
        sys.exit(1)

    # Report HTTP-level failures with the body the server sent back.
    if not response.ok:
        print(f"API request failed with status {response.status_code}: {response.text}")
        sys.exit(1)

    # response.json() raises a ValueError subclass when the body is not JSON.
    try:
        response_json = response.json()
    except ValueError as error:
        print(f"Error decoding the API response: {error}")
        sys.exit(1)

    try:
        get_response(response_json, input_questions, input_expected_answers)
    except ResponseParsingError as error:
        print(f"Error parsing the assistant response {error}")
        sys.exit(1)
question expected_answer
What type of chart is this? Bar
What is the label or title on the Y-axis? Turnover in million GBP
Are all the bars in the chart or graph vertical? Yes
What is the value of the largest bar in the chart? 73.4
What is the value for the year 2018? 68.5
What is the value for the year 2016? 58.5
What is the turnover for the year 2010? 33
For which year is the turnover highest? 2019
Is the turnover for the year 2017 more than the year 2016? Yes
What is the color of the bars in the chart? Blue
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment