Created
June 29, 2021 02:40
-
-
Save danyaljj/a7575a1be088cd1a01be378c9efaca61 to your computer and use it in GitHub Desktop.
prompting gpt2 with "soft" prompts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.distributions import Categorical | |
from transformers import GPT2Tokenizer, GPT2LMHeadModel | |
import torch | |
import torch.nn.functional as F | |
def embed_inputs(embedding, logits, device='cuda', print_entropy=False):
    """Embed a soft token distribution as a dense input representation.

    Normally a single token id (a one-hot vector) is embedded. Here the
    input is dense, so the logits are normalised with a softmax over the
    last axis (they then sum to one, like a one-hot vector) and used to take
    a probability-weighted mix of the rows of ``embedding.weight``.

    Args:
        embedding: object exposing a ``weight`` matrix; its first dimension
            must match the last dimension of ``logits``.
        logits: unnormalised scores; the softmax runs over the last axis.
        device: device the probabilities are moved to before the matmul.
        print_entropy: when true, print the entropy of the softmax
            distribution, summed over every element of ``logits``.

    Returns:
        ``softmax(logits) @ embedding.weight``.
    """
    probs = F.softmax(logits, dim=-1)

    if print_entropy:
        # log(softmax(x)) is -inf wherever a probability underflows to 0
        # (common at low temperatures), and 0 * -inf is nan. log_softmax
        # stays finite there; entries with zero probability are masked out
        # so that -inf logits cannot poison the sum either.
        log_probs = F.log_softmax(logits, dim=-1)
        terms = torch.where(
            probs > 0, -probs * log_probs, torch.zeros_like(probs)
        )
        entropy = torch.sum(terms)
        print(entropy)

    probs = probs.to(device)
    return torch.matmul(probs, embedding.weight)
def _greedy(logits):
    """Return the index of the largest entry along the last axis.

    The reduced axis is kept with size 1, so a 1-D input yields a tensor of
    shape ``(1,)``.
    """
    top = torch.topk(logits, k=1, dim=-1)
    return top.indices
def one_hot(tensor, dimension):
    """One-hot encode a tensor of integer class indices.

    Inputs with fewer than two dimensions are first padded with leading
    singleton axes, so a 1-D tensor of ``n`` ids becomes ``(1, n)`` before
    encoding.

    Args:
        tensor: int64 tensor of indices in ``[0, dimension)``.
        dimension: number of classes (size of the new trailing axis).

    Returns:
        An int64 tensor on the same device as ``tensor`` with one extra
        trailing axis of size ``dimension``.
    """
    while tensor.dim() < 2:
        tensor = tensor.unsqueeze(0)
    # The library routine allocates directly on the input's device and works
    # for any rank, unlike a hand-rolled 3-D scatter into a CPU buffer.
    return F.one_hot(tensor, num_classes=dimension)
def get_text_from_logits(logits, tokenizer):
    """Greedily decode a sequence of per-step logits into text.

    Args:
        logits: 2-D tensor with one row of vocabulary scores per step.
        tokenizer: object whose ``decode`` method turns a list of token ids
            into a string.

    Returns:
        A ``(text, nll, token_ids)`` tuple:
        ``text`` is the decoded string with newlines replaced by spaces;
        ``nll`` is the negated sum, over steps, of the log-probability that
        each row's softmax assigns to its chosen token;
        ``token_ids`` is a 1-D tensor holding the chosen id for every step.
    """
    if logits.shape[0] == 0:
        # Nothing to decode: return an empty result instead of failing.
        no_ids = torch.empty(0, dtype=torch.long, device=logits.device)
        return tokenizer.decode([]).replace('\n', ' '), 0.0, no_ids

    # One batched pass over all steps; `chosen` keeps a trailing axis of
    # size 1, which is exactly the index shape `gather` needs.
    chosen = _greedy(logits)
    step_logp = logits.log_softmax(-1).gather(-1, chosen).squeeze(-1)
    nll = -sum(step_logp.tolist())

    token_ids = chosen.squeeze(-1)
    text = tokenizer.decode(token_ids.tolist()).replace('\n', ' ')
    return text, nll, token_ids
# Load the pretrained GPT-2 tokenizer and language model, put the model on
# the GPU in inference mode, and build the one-hot encoded prompt that
# decode() reads as its starting context.
model_size = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_size)
model = (
    GPT2LMHeadModel
    .from_pretrained(model_size, output_hidden_states=True)
    .to('cuda')
    .eval()
)

input_ids = tokenizer.encode("The dog", return_tensors="pt")
input_ids = input_ids.to('cuda')
input_one_hot = one_hot(input_ids, tokenizer.vocab_size)
def decode(model, length, temperature, device):
    """GPT-2 decoding via dense representations (no arg-max).

    At every step the temperature-scaled logits of the last position are
    passed through ``embed_inputs`` (softmax followed by a weighted mix of
    the embedding rows) and the result is fed back as the next input,
    instead of committing to a single token.

    The prompt comes from the module-level ``input_one_hot``; it is also
    divided by ``temperature`` and embedded through ``embed_inputs``, which
    prints its entropy.

    Args:
        model: language model accepting ``inputs_embeds`` and
            ``past_key_values`` and returning an object with ``logits`` and
            ``past_key_values``.
        length: number of decoding steps.
        temperature: divisor applied to the prompt one-hots and to every
            step's logits.
        device: device the soft token distributions are moved to before
            being embedded.

    Returns:
        The temperature-scaled last-position logits of every step,
        concatenated along dim 1, or ``None`` when ``length`` is 0.
    """
    embedding = model.get_input_embeddings()
    past = None
    inputs_embeds = None
    step_logits = []

    for _ in range(length):
        if past is None:
            # First step: embed the prompt. `.float()` keeps the tensor on
            # its current device (no CPU round trip) and the caller's
            # `device` is honoured rather than a hard-coded 'cuda'.
            prompt_logits = input_one_hot.float() / temperature
            inputs_embeds = embed_inputs(
                embedding, prompt_logits, device=device, print_entropy=True
            )

        model_outputs = model(past_key_values=past, inputs_embeds=inputs_embeds)
        past = model_outputs.past_key_values

        # Keep only the last position, preserving the step axis (dim 1).
        logits = model_outputs.logits[:, -1:, :] / temperature
        step_logits.append(logits)
        inputs_embeds = embed_inputs(embedding, logits, device=device)

    if not step_logits:
        return None
    # Concatenate once at the end instead of re-copying on every step.
    return torch.cat(step_logits, dim=1)
# Run the dense decoder at a range of temperatures and print the greedy
# read-out of the logits produced by each run.
for temperature in (0.001, 0.01, 0.1, 0.2, 0.3, 1, 5):
    print(f" ------- \n * temperature: {temperature}")
    logits_so_far = decode(model, length=100, temperature=temperature, device='cuda')
    # Index 0 selects the first entry along the leading axis.
    text, nll, _ = get_text_from_logits(logits_so_far[0], tokenizer)
    print(text)
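# Expected output. The tensor(...) line of each run is the prompt entropy
# printed by embed_inputs; the temperature 0.2 run is not shown below.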
# ------- | |
# * temperature: 0.001 | |
# tensor(nan) | |
# was found in a field near the intersection of West and West Streets. The dog was taken to the hospital with a broken collarbone and a broken leg. The dog was taken to the hospital with a broken collarbone and a broken leg. The dog was taken to the hospital with a broken collarbone and a broken leg. The dog was taken to the hospital with a broken collarbone and a broken leg. The dog was taken to the hospital with | |
# ------- | |
# * temperature: 0.01 | |
# tensor(3.8029e-37) | |
# was found in a field near the intersection of West and West Streets. The dog was taken to the hospital with a broken collarbone and a broken leg. The dog was taken to the hospital with a broken collarbone and a broken leg. The dog was taken to the hospital with a broken collarbone and a broken leg. The dog was taken to the hospital with a broken collarbone and a broken leg. The dog was taken to the hospital with | |
# ------- | |
# * temperature: 0.1 | |
# tensor(16.2817) | |
# the dog's head, and the dog's head is the dog's head. (2) If the dog's head is the dog's head, and the dog's head is the dog's head, the dog's head is the dog's head. (3) If the dog's head is the dog's head, and the dog's head is the dog's head, the dog's head is the dog's head. (4) If the dog's head | |
# ------- | |
# * temperature: 0.3 | |
# tensor(21.6467) | |
# G-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1- | |
# ------- | |
# * temperature: 1 | |
# tensor(21.6488) | |
# G. J S I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I | |
# ------- | |
# * temperature: 5 | |
# tensor(21.6482) | |
# G J ( ( ( ( ( ( ( ( ( :: :: :: :: The The The The The The The The The The The The The The January January January January January January January January January January January ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment