@hamletbatista
Created September 4, 2020 23:52
from itertools import chain

import torch
from nltk.tokenize import sent_tokenize
from transformers import T5ForConditionalGeneration, T5Tokenizer


class QueGenerator():
    def __init__(self):
        # load the fine-tuned T5 question-generation and answer-generation models
        self.que_model = T5ForConditionalGeneration.from_pretrained('./t5_que_gen_model/t5_base_que_gen/')
        self.ans_model = T5ForConditionalGeneration.from_pretrained('./t5_ans_gen_model/t5_base_ans_gen/')

        self.que_tokenizer = T5Tokenizer.from_pretrained('./t5_que_gen_model/t5_base_tok_que_gen/')
        self.ans_tokenizer = T5Tokenizer.from_pretrained('./t5_ans_gen_model/t5_base_tok_ans_gen/')

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.que_model = self.que_model.to(self.device)
        self.ans_model = self.ans_model.to(self.device)

    def generate(self, text):
        # extract candidate answers first, then generate one question per answer
        answers = self._get_answers(text)
        questions = self._get_questions(text, answers)
        output = [{'answer': ans, 'question': que} for ans, que in zip(answers, questions)]
        return output

    def _get_answers(self, text):
        # split into sentences
        sents = sent_tokenize(text)

        # build one input per sentence, wrapping the target sentence in [HL] highlight markers
        examples = []
        for i in range(len(sents)):
            input_ = ""
            for j, sent in enumerate(sents):
                if i == j:
                    sent = "[HL] %s [HL]" % sent
                input_ = "%s %s" % (input_, sent)
            input_ = input_.strip()
            input_ = input_ + " </s>"
            examples.append(input_)

        batch = self.ans_tokenizer.batch_encode_plus(examples, max_length=512, pad_to_max_length=True, return_tensors="pt")
        with torch.no_grad():
            outs = self.ans_model.generate(input_ids=batch['input_ids'].to(self.device),
                                           attention_mask=batch['attention_mask'].to(self.device),
                                           max_length=32,
                                           # do_sample=False,
                                           # num_beams=4,
                                           )

        dec = [self.ans_tokenizer.decode(ids, skip_special_tokens=False) for ids in outs]

        # the answer model may emit several answers per sentence, separated by '[SEP]'
        answers = [item.split('[SEP]') for item in dec]
        answers = chain(*answers)
        answers = [ans.strip() for ans in answers if ans != ' ']
        return answers

    def _get_questions(self, text, answers):
        # pair every extracted answer with the full passage: "<answer> [SEP] <passage>"
        examples = []
        for ans in answers:
            input_text = "%s [SEP] %s </s>" % (ans, text)
            examples.append(input_text)

        batch = self.que_tokenizer.batch_encode_plus(examples, max_length=512, pad_to_max_length=True, return_tensors="pt")
        with torch.no_grad():
            outs = self.que_model.generate(input_ids=batch['input_ids'].to(self.device),
                                           attention_mask=batch['attention_mask'].to(self.device),
                                           max_length=32,
                                           num_beams=4)

        dec = [self.que_tokenizer.decode(ids, skip_special_tokens=False) for ids in outs]
        return dec
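
A minimal usage sketch, assuming the fine-tuned model and tokenizer directories referenced above exist locally and that NLTK's punkt sentence-tokenizer data has been downloaded (nltk.download('punkt')). The passage is purely illustrative; generate() returns a list of {'answer': ..., 'question': ...} dicts, one per extracted answer.

import nltk
nltk.download('punkt')  # required by sent_tokenize

generator = QueGenerator()

# illustrative input passage
passage = ("T5 is a text-to-text transformer model released by Google in 2019. "
           "It frames every NLP task as generating target text from input text.")

for pair in generator.generate(passage):
    print(pair['question'], '->', pair['answer'])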