Skip to content

Instantly share code, notes, and snippets.

@preetum
Created May 5, 2020 05:23
Show Gist options
  • Save preetum/f69f7287dc4b34f8c71bf45133dc96c7 to your computer and use it in GitHub Desktop.
# AutoModelWithLMHead is deprecated; AutoModelForCausalLM is the documented
# drop-in replacement for decoder-only LMs such as GPT-2.
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging

# --- configuration -------------------------------------------------------
model_name = 'gpt2-xl'
dev = 'cuda'
#dev = 'cpu'

# Load tokenizer and model once at import time; .to(dev) moves the weights
# onto the selected device (downloads ~6 GB on first run for gpt2-xl).
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(dev)

# Suppress transformers' informational/warning log spam during generation.
logging.getLogger().setLevel(logging.ERROR)
def generate(context, max_length=50):
    """Sample one completion of *context* from the loaded language model.

    Parameters
    ----------
    context : str
        The prompt text to be completed.
    max_length : int
        Maximum total sequence length in tokens (prompt included).

    Returns
    -------
    str
        Decoded text of the sampled sequence, prompt included.
    """
    # Tokenize the prompt and move the tensor to the model's device.
    encoded = tokenizer.encode(context, return_tensors="pt").to(dev)

    # Stochastic (sampling) decoding; a single sequence is requested.
    sequences = model.generate(
        input_ids=encoded,
        max_length=max_length,
        do_sample=True,
        num_return_sequences=1,
    )

    first_sequence = sequences.tolist()[0]
    return tokenizer.decode(
        first_sequence,
        clean_up_tokenization_spaces=True,
        skip_special_tokens=True,
    )
def run_interactive(max_length=50):
    """Read prompts from stdin and print sampled completions until 'quit'.

    For each prompt entered, two independently sampled completions are
    printed. Typing 'quit' ends the loop.

    Parameters
    ----------
    max_length : int
        Maximum total sequence length in tokens, forwarded to generate().
    """
    prompt = "Enter context to be completed (or quit): "
    context = input(prompt)
    while context != 'quit':
        for seq_index in range(2):
            print("=== GENERATED SEQUENCE {} ===".format(seq_index + 1))
            print(generate(context, max_length=max_length))
        context = input(prompt)
# Guard the script entry point so importing this module does not
# immediately start the interactive loop.
if __name__ == "__main__":
    run_interactive()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment