@martingaido
Created November 9, 2022 00:01
Generating Human-level Text with Contrastive Search in Transformers
# pip install torch
# pip install "transformers==4.24.0"
# usage: python text-generator-transformers.py 'some text'

import sys

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# The prompt (prefix) is taken from the first command-line argument.
if len(sys.argv) < 2:
    sys.exit("usage: python text-generator-transformers.py 'some text'")
prefix_text = sys.argv[1]

print('\nWait, this process may take a while...\n')

# Load the tokenizer and model; pad with the EOS token so generate()
# does not warn about a missing pad token.
model_name = 'gpt2-large'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id)
model.eval()

# Prepare the prefix.
input_ids = tokenizer(prefix_text, return_tensors='pt').input_ids

# Generate the result with contrastive search: penalty_alpha is the
# degeneration penalty, top_k the number of candidates considered per step.
with torch.no_grad():
    output = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=512)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(output[0], skip_special_tokens=True))
print(100 * '-')