Created
November 9, 2022 00:01
-
-
Save martingaido/00d44ac7b9001c53bc8df1d5b2af1ad0 to your computer and use it in GitHub Desktop.
Generating Human-level Text with Contrastive Search in Transformers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pip install torch
# pip install "transformers==4.24.0"
# usage: python text-generator-transformers.py 'some text'

import sys

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel


def main() -> None:
    """Generate text from a command-line prompt with GPT-2 using contrastive search.

    Reads the prompt from argv[1], loads the ``gpt2-large`` checkpoint, and
    prints the decoded continuation to stdout.
    """
    # Validate the prompt argument up front instead of crashing with IndexError.
    if len(sys.argv) < 2:
        print(f"usage: python {sys.argv[0]} 'some text'", file=sys.stderr)
        sys.exit(1)
    prefix_text = sys.argv[1]

    print('')
    print('Wait, this process may take a while...')
    print('')

    model_name = 'gpt2-large'
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    # GPT-2 has no pad token; reuse EOS so generate() can pad without warnings.
    model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id)
    model.eval()

    # Prepare the prefix.
    input_ids = tokenizer(prefix_text, return_tensors='pt').input_ids

    # Generate with contrastive search (penalty_alpha + top_k > 1 triggers it).
    # Inference only — disable gradient tracking to save memory.
    with torch.no_grad():
        output = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=512)

    print("Output:\n" + 100 * '-')
    print(tokenizer.decode(output[0], skip_special_tokens=True))
    print("" + 100 * '-')


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.