@tezansahu
Last active April 23, 2022 21:13
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from sentence_splitter import SentenceSplitter

class PegasusParaphraser:
    def __init__(self, num_beams=10):
        # Use a GPU if one is available, otherwise fall back to the CPU
        if torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")

        # Pegasus tokenizer & model fine-tuned for paraphrasing
        print("Loading Pegasus Tokenizer & Model for Paraphrasing.")
        paraphraser_model_name = "tuner007/pegasus_paraphrase"
        self.tokenizer = PegasusTokenizer.from_pretrained(paraphraser_model_name)
        self.model = PegasusForConditionalGeneration.from_pretrained(paraphraser_model_name).to(self.device)
        self.num_beams = num_beams

        # To split the paragraph into individual sentences
        self.splitter = SentenceSplitter(language="en")

    def paraphrase_text(self, text):
        # Split the paragraph into sentences and paraphrase each one in a single batch
        sentence_list = self.splitter.split(text)
        batch = self.tokenizer(sentence_list, truncation=True, padding="longest", max_length=100, return_tensors="pt").to(self.device)
        # Beam search decoding, keeping one paraphrase per input sentence
        translated = self.model.generate(**batch, max_length=60, num_beams=self.num_beams, num_return_sequences=1, temperature=1.5)
        tgt_text = self.tokenizer.batch_decode(translated, skip_special_tokens=True)
        # Stitch the paraphrased sentences back into a single paragraph
        paraphrased_text = " ".join(tgt_text)
        return paraphrased_text
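
A minimal usage sketch, assuming the torch, transformers, and sentence-splitter packages are installed; the sample text and variable names below are only illustrative:

if __name__ == "__main__":
    # Instantiate the paraphraser (downloads the tuner007/pegasus_paraphrase checkpoint on first run)
    paraphraser = PegasusParaphraser(num_beams=10)

    # Hypothetical input paragraph; each sentence is paraphrased independently
    sample_text = (
        "Pegasus is a transformer model originally pre-trained for abstractive summarization. "
        "The tuner007/pegasus_paraphrase checkpoint fine-tunes it for sentence-level paraphrasing."
    )
    print(paraphraser.paraphrase_text(sample_text))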