Skip to content

Instantly share code, notes, and snippets.

@tezansahu
Last active February 12, 2022 14:25
Show Gist options
  • Save tezansahu/f983361a53866ce0f40c1dd543584aec to your computer and use it in GitHub Desktop.
Save tezansahu/f983361a53866ce0f40c1dd543584aec to your computer and use it in GitHub Desktop.
# Read the train / eval splits from their CSV files under ./dataset
dataset = load_dataset(
    "csv",
    data_files=dict(
        train=os.path.join("dataset", "data_train.csv"),
        test=os.path.join("dataset", "data_eval.csv"),
    ),
)
# Load the space of all possible answers (one answer per line).
# The encoding is given explicitly so the file decodes the same way on every
# platform instead of depending on the locale's default encoding.
with open(os.path.join("dataset", "answer_space.txt"), encoding="utf-8") as f:
    answer_space = f.read().splitlines()
# Since we model the VQA task as a multiclass classification problem,
# we need to create the labels from the actual answers.
# The answer -> class-index lookup is built once so every example costs a
# dict lookup instead of a linear `list.index` scan over the answer space.
_label_by_answer = {}
for _label, _answer in enumerate(answer_space):
    # Keep the first occurrence so duplicates resolve exactly like `list.index`
    _label_by_answer.setdefault(_answer, _label)


def _first_answer_label(ans):
    """Return the class index of the first answer listed in ``ans``.

    All spaces are removed from ``ans``; if multiple comma-separated answers
    are provided for a single question, only the 1st one is used.

    Raises:
        ValueError: if that answer is not present in ``answer_space``.
    """
    first_answer = ans.replace(" ", "").split(",")[0]
    try:
        return _label_by_answer[first_answer]
    except KeyError:
        # Preserve the exception type that `list.index` raised originally
        raise ValueError(f"{first_answer!r} is not in answer_space") from None


dataset = dataset.map(
    lambda examples: {
        'label': [_first_answer_label(ans) for ans in examples['answer']]
    },
    batched=True
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment