Skip to content

Instantly share code, notes, and snippets.

@tezansahu
Last active January 25, 2023 10:48
Show Gist options
  • Save tezansahu/caed8c19cac5a840eba9f6b0326aa371 to your computer and use it in GitHub Desktop.
# Regex that matches the trailing "... in the image12 ?" fragment of each
# question, capturing the image ID (group 3 of the match tuple) so every
# QA pair can be linked back to the image it refers to.
image_pattern = re.compile(r"( (in |on |of )?(the |this )?(image\d*) \?)")

with open(os.path.join("dataset", "all_qa_pairs.txt")) as f:
    qa_data = [line.replace("\n", "") for line in f.readlines()]

# The file alternates question lines and answer lines, so step by 2.
# Build a list of records and construct the DataFrame once at the end:
# DataFrame.append was removed in pandas 2.0 and was O(n^2) anyway.
records = []
for i in range(0, len(qa_data), 2):
    # NOTE(review): this assumes every question line matches the pattern;
    # a malformed line raises IndexError, same as the original behavior.
    match = image_pattern.findall(qa_data[i])[0]
    records.append({
        "question": qa_data[i].replace(match[0], ""),  # strip image suffix
        "answer": qa_data[i + 1],
        "image_id": match[3],
    })
df = pd.DataFrame(records, columns=["question", "answer", "image_id"])
# Collect every distinct answer so that the answer-generation part of the
# VQA task can be modelled as multiclass classification over a fixed
# answer space. A set avoids the quadratic repeated list concatenation
# of building answer_space + [...] on every iteration.
answers = set()
for ans in df.answer.to_list():
    if "," in ans:
        # Multi-answer entries like "red, blue" contribute each answer
        # separately (spaces removed, matching the original behavior).
        answers.update(ans.replace(" ", "").split(","))
    else:
        answers.add(ans)
answer_space = sorted(answers)

with open(os.path.join("dataset", "answer_space.txt"), "w") as f:
    f.writelines("\n".join(answer_space))
# The original dataset reserves only ~54% of the data for training (too
# little), so we produce our own split with 80% of the data for training.
# random_state is fixed so the split is reproducible across runs.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
# index=False (the documented bool form) so the DataFrame index is not
# written as an extra column in the CSVs.
train_df.to_csv(os.path.join("dataset", "data_train.csv"), index=False)
test_df.to_csv(os.path.join("dataset", "data_eval.csv"), index=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment