Skip to content

Instantly share code, notes, and snippets.

@joelgrus
Last active March 2, 2019 20:47
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save joelgrus/73aed719ce894ecd9d21e0c72fba3e43 to your computer and use it in GitHub Desktop.
Save joelgrus/73aed719ce894ecd9d21e0c72fba3e43 to your computer and use it in GitHub Desktop.
Sample text from the Hugging Face implementation of OpenAI GPT-2.
# pip install "pytorch-pretrained-bert>=0.6"   (quote the spec, or the shell treats ">" as an output redirect)
from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer
from pytorch_pretrained_bert.modeling_gpt2 import GPT2LMHeadModel
import torch
# Load the pretrained GPT-2 tokenizer and language-model head once, at import time.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
# This script only does inference: switch off dropout so sampling uses the
# model's actual predictive distribution (the pytorch-pretrained-bert examples
# call eval() explicitly after from_pretrained).
model.eval()
# Token id of the end-of-text marker; sampling this id ends generation.
END_OF_TEXT = tokenizer.encoder["<|endoftext|>"]
# Default prompt used by generate() and the demo below.
SEED = "Twitter is"
def generate(seed: str = SEED, num_steps: int = 20, temperature: float = 1.0) -> str:
    """Sample a continuation of ``seed`` from GPT-2, one token at a time.

    Args:
        seed: Prompt text to continue. Must encode to at least one token.
        num_steps: Maximum number of new tokens to sample.
        temperature: Positive divisor applied to the logits before sampling.
            Values above 1.0 flatten the distribution (more diverse output);
            values below 1.0 sharpen it. 1.0 samples the model's distribution
            unchanged.

    Returns:
        The decoded prompt followed by the sampled tokens. Sampling stops early
        if the end-of-text marker is drawn; the marker itself is not included.

    Raises:
        ValueError: if ``temperature`` is not positive or ``seed`` encodes to
            no tokens.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive, got {!r}".format(temperature))
    token_ids = tokenizer.encode(seed)
    if not token_ids:
        # The model needs at least one input position to produce logits.
        raise ValueError("seed must encode to at least one token")
    # State returned by the previous step and fed back as `past`. Per the
    # pytorch-pretrained-bert docs, this holds the attention keys/values, so
    # later steps only need to feed the newly sampled token.
    presents = None
    # The first step feeds the whole prompt; later steps feed one token.
    # NOTE(review): GPT-2 has a fixed context window (1024 positions for
    # 'gpt2'); a long seed plus many steps will exceed it.
    inputs = torch.LongTensor([token_ids])
    # Inference only: without no_grad an autograd graph would be built each
    # step and kept alive through `presents`, growing memory with num_steps.
    with torch.no_grad():
        for _ in range(num_steps):
            logits, presents = model(inputs, past=presents)
            # Scores for the token following the last input position.
            dist = torch.distributions.Categorical(logits=logits[0, -1] / temperature)
            next_id = dist.sample().item()
            if next_id == END_OF_TEXT:
                break
            token_ids.append(next_id)
            inputs = torch.LongTensor([[next_id]])
    return tokenizer.decode(token_ids)
if __name__ == "__main__":
    # Demo: sample up to 50 tokens continuing the default prompt.
    print(generate(seed=SEED, num_steps=50))
@JulesGM
Copy link

JulesGM commented Mar 2, 2019

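Add a `temperature` argument to control diversity (values above 1.0 make samples more varied, values below 1.0 make them more conservative): change the signature to `def generate(seed: str = SEED, num_steps: int = 20, temperature: float = 1.0) -> str:` and sample with `torch.distributions.Categorical(logits=logits[0, -1] / temperature)`.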

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment