Skip to content

Instantly share code, notes, and snippets.

@shahules786
Created May 5, 2023 13:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shahules786/0c16346e940c0e093f3aadf10c8eac96 to your computer and use it in GitHub Desktop.
Save shahules786/0c16346e940c0e093f3aadf10c8eac96 to your computer and use it in GitHub Desktop.
class WebGPT:
    """Preference-ranked answers from the ``openai/webgpt_comparisons`` dataset.

    Each dataset row compares two answers to one question. Rows are grouped by
    question id. For every row with a strict preference, the two cleaned
    answers are appended to that question's list as ``[preferred, rejected]``.
    Rows whose ``score_0`` is neither > 0 nor < 0 (a tie) contribute nothing.
    Questions that have no strictly preferred comparison are omitted.
    """

    name = "openai/webgpt_comparisons"

    # Matches a numeric citation marker such as "[3]" together with any
    # whitespace in front of it, so "text [1]. More" becomes "text. More".
    # The inter-sentence spacing is left untouched.
    _CITATION_RE = re.compile(r"\s*\[\d+\]")

    def __init__(self, split: str = "train", *, dataset=None):
        """Load ``split`` and index the ranked answers by question id.

        Args:
            split: Dataset split name passed to ``load_dataset``.
            dataset: Optional, already-loaded iterable of comparison rows.
                Each row is a mapping with ``question`` (holding ``id`` and
                ``full_text``), ``score_0``, ``answer_0`` and ``answer_1``.
                When given, ``load_dataset`` is not called.
        """
        super().__init__()
        self.split = split
        if dataset is None:
            dataset = load_dataset(self.name, split=self.split)
        # question id -> {"full_text": str, "answers": list of str}
        self.dataset_dict = defaultdict(dict)
        for item in dataset:
            answers = self._ranked_answers(item)
            if not answers:
                # Tie: skip only this row. Answers already collected for the
                # same question from other comparisons must be kept.
                continue
            post_id = item["question"]["id"]
            if post_id not in self.dataset_dict:
                self.dataset_dict[post_id] = {
                    "full_text": item["question"]["full_text"],
                    "answers": [],
                }
            self.dataset_dict[post_id]["answers"].extend(answers)
        # Questions in order of their first non-tied comparison.
        self.post_ids = list(self.dataset_dict)

    @classmethod
    def _clean_answer(cls, answer: str) -> str:
        """Remove ``[n]`` citation markers and strip surrounding whitespace."""
        return cls._CITATION_RE.sub("", answer).strip()

    @classmethod
    def _ranked_answers(cls, item) -> "list[str]":
        """Return cleaned ``[preferred, rejected]`` answers, or ``[]`` on a tie.

        A positive ``score_0`` prefers ``answer_0``; a negative one prefers
        ``answer_1``. Any other value (0, NaN) is treated as no preference.
        """
        score = item["score_0"]
        if score > 0:
            ranked = (item["answer_0"], item["answer_1"])
        elif score < 0:
            ranked = (item["answer_1"], item["answer_0"])
        else:
            return []
        return [cls._clean_answer(text) for text in ranked]

    def __len__(self) -> int:
        """Return the number of questions with at least one ranked comparison."""
        return len(self.post_ids)

    def __getitem__(self, idx: int) -> "tuple[str, list[str]]":
        """Return ``(question_text, answers)`` for the ``idx``-th question.

        ``answers`` is a flat list of ``[preferred, rejected]`` pairs: one pair
        per non-tied comparison row, in dataset order.
        """
        entry = self.dataset_dict[self.post_ids[idx]]
        return entry["full_text"], entry["answers"]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment