GitHub Gist sravantit25/d1879eddbab15a62b12c4223460d6e3f
Last active December 14, 2023 10:46

question | expected_answer
---|---
What is the chart type? | Bar
What is the title of the chart? | Country Summary
What is the value for Austria for the year 2004? | 57000
What is the value for Brazil for the year 2003? | 20000
What is the value for France for the year 2004? | 46000
Which color represents the year 2003? | Red
Which country has the highest value for the year 2004? | Italy
For the country USA, is the value for the year 2005 more than for the year 2004? | Yes
What is the label or title of the X-axis? | Country
Which country has the highest value for the year 2005? | USA
What is the value for France for the year 2005? | 19000

question | expected_answer
---|---
What is the title of the chart? | Investment in AI
What is the chart type? | Bar chart
What is the investment in AI for the year 2022? | 175 billion dollars
What is the investment in the year 2018? | 85 billion dollars
What is the value for the year 2017? | 60 billion dollars
What colors are the bars in the chart? | Green
For which year is the investment highest? | 2021
Is the value for the year 2020 more than for the year 2019? | Yes
What is the label or title on the X-axis? | Years
What is the label or title on the Y-axis? | value
What is the value represented by the smallest bar in the chart? | 60 billion dollars
What does the legend represent or display? | billion dollars
What is the percentage increase from the year 2020 to the year 2021? | 79 percent
Are there 6 bars on the chart? | Yes
Are all the bars on the chart vertical? | Yes
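The percentage-increase question asks the model to derive a number from two bar heights rather than read one off the chart. The 2020 and 2021 values are not listed in this table, so the sketch below only shows the formula, (new - old) / old × 100, applied to hypothetical inputs.

```python
def percentage_increase(old_value: float, new_value: float) -> float:
    """Return the percentage change from old_value to new_value."""
    if old_value == 0:
        raise ValueError("old_value must be non-zero")
    return (new_value - old_value) / old_value * 100


# Hypothetical values, not read from the chart
print(round(percentage_increase(100, 179), 1))  # 79.0
```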
""" | |
This script helps to perform Visual Question Answering on charts using GPT-4 Vision model. | |
It reads questions and expected answers from a CSV file, provides them to the model | |
and captures the response. | |
Note: Update the 'API_KEY', 'image_path', and 'file_path' variables with your specific values. | |
""" | |
import os
import base64
import csv
import sys

import requests

# The OpenAI API key is read from the API_KEY environment variable
API_KEY = os.getenv("API_KEY")
API_URL = "https://api.openai.com/v1/chat/completions"
MODEL = "gpt-4-vision-preview"


class ReadCSVError(Exception):
    """Raised when the CSV file cannot be read or has an invalid structure."""


class PayloadGenerationError(Exception):
    """Raised when the request payload cannot be generated."""


class ResponseParsingError(Exception):
    """Raised when the model response cannot be parsed."""


def encode_image(image_path):
    """
    Encode an image to base64
    Params:
    - image_path (str): The path to the image file
    Returns:
    - str: base64-encoded image data
    """
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
    except FileNotFoundError as file_not_found_error:
        raise PayloadGenerationError(
            f"Cannot find the image file: {file_not_found_error}"
        ) from file_not_found_error


def prepare_input_questions(questions):
    """
    Prepare the list of questions to be fed to the model
    Params:
    - questions (list): List of strings representing questions
    Returns:
    - list: List of dictionaries representing the content for the user message
    """
    if not isinstance(questions, list):
        raise ValueError("Invalid input: 'questions' must be a list.")
    content = []
    for question in questions:
        content.append({"type": "text", "text": question})
    return content


def generate_payload(api_key, model, image_path, questions):
    """
    Generate the headers and payload for the OpenAI API
    Params:
    - api_key (str): OpenAI API key
    - model (str): OpenAI model identifier
    - image_path (str): Path to the image file (JPEG or PNG)
    - questions (list): List of strings representing questions
    Returns:
    - tuple: A tuple containing the headers and payload for the OpenAI API
    """
    base64_image = encode_image(image_path)
    # Label the data URL with the image's actual format instead of always using JPEG
    mime_type = "image/png" if image_path.lower().endswith(".png") else "image/jpeg"
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
    # Generate payload: the questions as text parts, followed by the image
    try:
        prepared_questions = prepare_input_questions(questions)
    except ValueError as value_error:
        raise PayloadGenerationError(value_error) from value_error
    payload = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": prepared_questions
                + [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime_type};base64,{base64_image}"},
                    }
                ],
            }
        ],
        "max_tokens": 300,
    }
    return headers, payload


def get_response(response_json, questions, expected_answers):
    """
    Parse the model's response and print each answer next to its question and
    expected answer. Assumes the model answers one question per line.
    Params:
    - response_json (dict): The JSON response from the OpenAI API
    - questions (list): List of strings representing questions
    - expected_answers (list): List of strings representing expected answers
    """
    if "error" in response_json:
        # On failure the API returns an "error" object instead of "choices"
        raise ResponseParsingError(response_json["error"])
    try:
        assistant_message = response_json["choices"][0]["message"]["content"]
        # Skip the blank lines the model puts between its answers
        answers = [line for line in assistant_message.split("\n") if line.strip()]
        for i, answer in enumerate(answers):
            print(f"Question {i}: {questions[i]}")
            print(f"Answer {i}: {answer}")
            print(f"Expected Answer {i}: {expected_answers[i]}")
            print()
    except KeyError as key_error:
        raise ResponseParsingError(key_error) from key_error
    except IndexError as index_error:
        raise ResponseParsingError(index_error) from index_error


def process_row(row):
    """
    Process a row from the CSV file and return the question and the expected answer
    Params:
    - row (list): A list of strings representing a row from the CSV file
    Returns:
    - tuple: A tuple containing the question and the expected answer
    """
    # Each row must have exactly two columns: question and expected answer
    if len(row) == 2:
        question = row[0]
        expected_answer = row[1]
        return question, expected_answer
    raise ReadCSVError(f"Invalid CSV file structure: expected 2 columns, got {len(row)}")


def read_csv(file_path):
    """
    Read questions and expected answers from a CSV file
    Params:
    - file_path (str): Path to the CSV file
    Returns:
    - tuple: A tuple containing two lists - 'questions' and 'expected_answers'
    """
    questions = []
    expected_answers = []
    try:
        with open(file_path, "r", encoding="utf-8", newline="") as csv_file:
            reader = csv.reader(csv_file)
            # Skip the header row (question,expected_answer); the default
            # avoids StopIteration when the file is empty
            next(reader, None)
            for row in reader:
                # Process each row and append the results to the lists
                question, expected_answer = process_row(row)
                questions.append(question)
                expected_answers.append(expected_answer)
    except FileNotFoundError as file_not_found_error:
        raise ReadCSVError(
            f"The specified CSV file is not found: {file_not_found_error}"
        ) from file_not_found_error
    except csv.Error as csv_error:
        raise ReadCSVError(f"Invalid CSV file: {csv_error}") from csv_error
    return questions, expected_answers
if __name__ == "__main__": | |
IMAGE_PATH = <a jpeg or png image of the chart> | |
FILE_PATH = <csv file consisting of questions> | |
try: | |
input_questions, input_expected_answers = read_csv(FILE_PATH) | |
except ReadCSVError as csvError: | |
print("Error reading CSV file: ", csvError) | |
sys.exit(1) | |
try: | |
input_headers, input_payload = generate_payload( | |
API_KEY, MODEL, IMAGE_PATH, input_questions | |
) | |
except PayloadGenerationError as error: | |
print(f"Error generating payload: {error}") | |
sys.exit(1) | |
response = requests.post( | |
API_URL, headers=input_headers, json=input_payload, timeout=30 | |
) | |
try: | |
get_response(response.json(), input_questions, input_expected_answers) | |
except ResponseParsingError as error: | |
print(f"Error parsing the assistant response {error}") | |
sys.exit(1) |
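The chart image travels inside the request as a base64 data URL. A minimal sketch of building such a URL with the standard-library `mimetypes` module, which guesses the media type from the file extension, follows. The helper name `to_data_url` is hypothetical, and falling back to `image/jpeg` is an assumption that mirrors the script's default.

```python
import base64
import mimetypes


def to_data_url(image_path: str, image_bytes: bytes) -> str:
    """Build a data URL, guessing the media type from the file extension."""
    mime_type, _ = mimetypes.guess_type(image_path)
    if mime_type is None:
        mime_type = "image/jpeg"  # assumed fallback, matching the script's default
    encoded = base64.b64encode(image_bytes).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"


print(to_data_url("chart.png", b"abc"))  # data:image/png;base64,YWJj
```

Taking the bytes as an argument keeps the helper free of file I/O; in the script they would come from reading `IMAGE_PATH` in binary mode.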
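`prepare_input_questions` sends every question as a separate text part, and `get_response` then assumes the reply contains one answer per line. One possible variation, sketched here under that same assumption, is to send a single numbered prompt that asks for numbered answers. The instruction wording and the name `build_numbered_prompt` are hypothetical and not part of the script.

```python
def build_numbered_prompt(questions: list[str]) -> list[dict]:
    """Combine all questions into one numbered text part for the user message."""
    if not isinstance(questions, list):
        raise ValueError("Invalid input: 'questions' must be a list.")
    # Hypothetical instruction asking the model for one numbered answer per line
    lines = [
        "Answer each question about the chart on its own line,",
        "prefixed with the question number (for example '1. Bar').",
    ]
    lines += [f"{number}. {question}" for number, question in enumerate(questions, start=1)]
    return [{"type": "text", "text": "\n".join(lines)}]


print(build_numbered_prompt(["What is the chart type?", "What is the title of the chart?"]))
```

The return value has the same shape as the list built by `prepare_input_questions` (a list of `{"type": "text", ...}` parts), so it could take that list's place in the message content.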
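`get_response` pairs the n-th non-empty line of the reply with the n-th question, so list numbering such as `1.` ends up in the printed answers, and a reply with more lines than questions raises `ResponseParsingError`. A hedged sketch of a stricter alignment step follows. The prefixes it strips (`1.`, `2)`, `Answer 3:`) are assumptions about how the model formats its reply, and `align_answers` is a hypothetical name.

```python
import re

# Assumed numbering styles: "1. ", "2) " or "Answer 3: " at the start of a line.
# Whitespace is required after the marker so values like "73.4" are left intact.
NUMBER_PREFIX = re.compile(r"^\s*(?:answer\s*\d+\s*:|\d+[.)])\s+", re.IGNORECASE)


def align_answers(assistant_message: str, questions: list[str]) -> list[tuple[str, str]]:
    """Pair each question with one non-empty reply line, stripping numbering."""
    answers = [
        NUMBER_PREFIX.sub("", line).strip()
        for line in assistant_message.splitlines()
        if line.strip()
    ]
    if len(answers) != len(questions):
        raise ValueError(
            f"Expected {len(questions)} answers but the reply has {len(answers)} lines"
        )
    return list(zip(questions, answers))


print(align_answers("1. Bar\n\n2) 73.4", ["Chart type?", "Largest value?"]))
```

Raising on a count mismatch, rather than silently pairing the wrong lines, makes a reply that wraps one answer over two lines visible instead of shifting every later answer by one.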
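The script prints each answer next to its expected answer and leaves the judgement to the reader. A rough automatic check is sketched below. Its rules (case-insensitive comparison after trimming punctuation, and counting a match when the expected answer appears inside the model's answer) are assumptions, not part of the original workflow; `normalize`, `is_match` and `score` are hypothetical names.

```python
import string


def normalize(text: str) -> str:
    """Lower-case and trim surrounding whitespace and punctuation."""
    return text.strip().strip(string.punctuation).strip().lower()


def is_match(model_answer: str, expected_answer: str) -> bool:
    """Loose match: equal after normalization, or expected text contained in the answer."""
    model, expected = normalize(model_answer), normalize(expected_answer)
    return model == expected or (expected != "" and expected in model)


def score(pairs: list[tuple[str, str]]) -> float:
    """Fraction of (model_answer, expected_answer) pairs that match."""
    if not pairs:
        return 0.0
    return sum(is_match(model, expected) for model, expected in pairs) / len(pairs)


print(score([("Yes.", "Yes"), ("The chart is a bar chart", "Bar"), ("Red", "Blue")]))
```

Substring matching is deliberately loose: it accepts "a bar chart" for "Bar", but it would also accept "2005" for an expected "200", and it cannot equate "175 billion dollars" with "$175B", so borderline cases still need a manual look.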

question | expected_answer
---|---
What type of chart is this? | Bar
What is the label or title on the Y-axis? | Turnover in million GBP
Are all the bars in the chart or graph vertical? | Yes
What is the value of the largest bar in the chart? | 73.4
What is the value for the year 2018? | 68.5
What is the value for the year 2016? | 58.5
What is the turnover for the year 2010? | 33
For which year is the turnover highest? | 2019
Is the turnover for the year 2017 more than for the year 2016? | Yes
What is the color of the bars in the chart? | Blue
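`read_csv` skips the first row of each file unconditionally and expects exactly two columns per row. A sketch of a header-aware alternative using `csv.DictReader` follows, assuming the header row is `question,expected_answer` as in the sample files. It accepts any text stream so it can be tried on an `io.StringIO`; the name `read_questions` is hypothetical.

```python
import csv
import io
from typing import TextIO


def read_questions(stream: TextIO) -> tuple[list[str], list[str]]:
    """Read questions and expected answers from a CSV stream that has a header row."""
    reader = csv.DictReader(stream)
    required = {"question", "expected_answer"}
    # fieldnames is None when the stream is empty
    if reader.fieldnames is None or not required.issubset(reader.fieldnames):
        raise ValueError(f"CSV header must contain {sorted(required)}")
    questions: list[str] = []
    expected_answers: list[str] = []
    for row in reader:
        # Short rows yield None for missing fields; treat them as empty strings
        questions.append((row["question"] or "").strip())
        expected_answers.append((row["expected_answer"] or "").strip())
    return questions, expected_answers


sample = io.StringIO("question,expected_answer\nWhat is the chart type?,Bar\n")
print(read_questions(sample))  # (['What is the chart type?'], ['Bar'])
```

To read from disk, pass a file opened with `open(path, newline="", encoding="utf-8")`. A file whose first line is a question rather than a header fails loudly here instead of silently losing that question.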