Skip to content

Instantly share code, notes, and snippets.

@felipenunezb
Created August 7, 2020 08:56
Show Gist options
  • Save felipenunezb/c32a0fad4271cb8b687b27cb61599b70 to your computer and use it in GitHub Desktop.
Save felipenunezb/c32a0fad4271cb8b687b27cb61599b70 to your computer and use it in GitHub Desktop.
Script to convert the NewsQA data file into SQuAD format
import os
import tqdm
import json
import argparse
class NewsQAPreprocessor:
    """Convert the combined NewsQA JSON file into SQuAD-style split files.

    The input file is read from ``os.path.join(path, filename)``.  Every
    article becomes one SQuAD entry whose title is the article's ``storyId``
    and whose single paragraph uses the article's ``text`` as context.
    Entries are routed to the train, dev or test output according to the
    article's ``type`` field; articles with any other type are dropped.
    """

    # Output splits, in the order their files are written.
    _SPLITS = ("train", "dev", "test")
    # Characters trimmed from both ends of a raw answer span (newlines are
    # trimmed afterwards, in a second pass).
    _EDGE_CHARS = ".| "

    def __init__(self, path, filename):
        """Store the dataset location; nothing is read until ``load_data``.

        Args:
            path: Directory containing the dataset file ("" for the current
                directory).
            filename: Name of the combined NewsQA JSON file.
        """
        self.filename = filename
        self.path = path
        self.data = None

    def load_data(self):
        """Parse the JSON dataset file into ``self.data``."""
        filepath = os.path.join(self.path, self.filename)
        with open(filepath, encoding="utf-8") as f:
            self.data = json.load(f)

    @classmethod
    def _build_answer(cls, context, start, end):
        """Return a SQuAD answer dict for ``context[start:end]``, or ``None``.

        The span is trimmed of surrounding ``.``, ``|`` and space characters
        and then of surrounding newlines.  ``answer_start`` is shifted by the
        number of leading characters removed, so that
        ``context[answer_start:answer_start + len(text)] == text``.

        ``None`` is returned when either bound is missing or when nothing is
        left after trimming.
        """
        if start is None or end is None:
            # Slicing with None bounds would silently return the whole
            # context, so a span-less entry must be rejected explicitly.
            return None
        raw = context[start:end]
        trimmed = raw.strip(cls._EDGE_CHARS)
        offset = len(raw) - len(raw.lstrip(cls._EDGE_CHARS))
        text = trimmed.strip("\n")
        offset += len(trimmed) - len(trimmed.lstrip("\n"))
        if not text:
            return None
        return {"text": text, "answer_start": start + offset}

    def _convert_question(self, context, question, qid):
        """Build the SQuAD ``qas`` entry for one NewsQA question.

        A question is answerable when its ``isAnswerAbsent`` value equals 0
        and its ``consensus`` entry carries a start offset ``s``.
        """
        answers = []
        is_impossible = True
        if question.get("isAnswerAbsent") == 0:
            consensus = question["consensus"]
            # Compare against None rather than relying on truthiness: a
            # consensus span that starts at offset 0 is still a valid answer.
            if consensus.get("s") is not None:
                is_impossible = False
                for answer in question["answers"]:
                    sourcer_answers = answer["sourcerAnswers"]
                    if not sourcer_answers:
                        continue
                    # Only the first span of each entry is used.  Entries
                    # without an "s"/"e" pair are skipped (NewsQA appears to
                    # use these for "no answer" votes -- TODO confirm against
                    # the dataset documentation).
                    first = sourcer_answers[0]
                    built = self._build_answer(
                        context, first.get("s"), first.get("e")
                    )
                    if built is not None:
                        answers.append(built)
                if not answers:
                    # No usable individual span: fall back to the consensus
                    # span so an answerable question is not left answerless.
                    fallback = self._build_answer(
                        context, consensus.get("s"), consensus.get("e")
                    )
                    if fallback is not None:
                        answers.append(fallback)
        return {
            "question": question["q"].strip(),
            "id": qid,
            "answers": answers,
            "is_impossible": is_impossible,
        }

    def _convert_articles(self, articles):
        """Convert an iterable of NewsQA articles into per-split entry lists.

        Returns:
            A dict mapping "train", "dev" and "test" to lists of SQuAD
            article entries.  Question ids are consecutive integers starting
            at 1 and are consumed by every article, including articles whose
            type matches none of the splits.
        """
        splits = {name: [] for name in self._SPLITS}
        qid = 1  # acts as the question id in the SQuAD output
        for article in articles:
            context = article["text"]
            qas = []
            for question in article["questions"]:
                qas.append(self._convert_question(context, question, qid))
                qid += 1
            entry = {
                "title": article["storyId"],
                "paragraphs": [{"qas": qas, "context": context}],
            }
            bucket = splits.get(article["type"])
            if bucket is not None:
                bucket.append(entry)
        return splits

    def split_data(self):
        """Load the dataset and write one SQuAD-style JSON file per split.

        The files ``train_newsqa.json``, ``dev_newsqa.json`` and
        ``test_newsqa.json`` are written to the current working directory.
        """
        self.load_data()
        splits = self._convert_articles(tqdm.tqdm(self.data["data"]))
        for name in self._SPLITS:
            self.write_data(
                {"version": "newsqa", "data": splits[name]},
                name + "_newsqa.json",
            )

    def preprocess(self):
        """Run the full conversion (alias for ``split_data``)."""
        self.split_data()

    def write_data(self, data, file_path):
        """Dump ``data`` as indented JSON to ``file_path``."""
        with open(file_path, 'w', encoding="utf-8") as write_file:
            json.dump(data, write_file, indent=4)
def main():
    """Parse the command-line options and run the NewsQA conversion."""
    cli = argparse.ArgumentParser()

    # Both options are optional: each one has a default value.
    # (flag, default, help text)
    option_specs = (
        (
            "--file_location",
            "",
            "Dataset location path. If executed in same directory as the file, no need to specify it.",
        ),
        (
            "--filename",
            "combined-newsqa-data-v1.json",
            "Combined newsqa dataset name",
        ),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, default=default_value, type=str, help=help_text)

    options = cli.parse_args()
    converter = NewsQAPreprocessor(options.file_location, options.filename)
    converter.preprocess()
# Run the conversion only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment