Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save danyaljj/a7575a1be088cd1a01be378c9efaca61 to your computer and use it in GitHub Desktop.
Save danyaljj/a7575a1be088cd1a01be378c9efaca61 to your computer and use it in GitHub Desktop.
prompting gpt2 with "soft" prompts
from torch.distributions import Categorical
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
import torch.nn.functional as F
def embed_inputs(embedding, logits, device='cuda', print_entropy=False):
    """Embed a dense (soft) token distribution instead of hard token ids.

    Normally a one-hot vector is embedded.  Here the inputs are dense, so
    ``logits`` are soft-maxed over the last axis to make every position sum
    to one (like a one-hot vector); the result is the probability-weighted
    mix of the rows of ``embedding.weight``.

    Args:
        embedding: module exposing a ``weight`` matrix whose first axis
            matches the last axis of ``logits`` (required by the matmul).
        logits: unnormalised scores; the last axis is soft-maxed.
        device: device the probabilities are moved to before the matmul.
        print_entropy: when true, print the entropy of the soft-maxed
            input, summed over every position.

    Returns:
        ``softmax(logits) @ embedding.weight``.
    """
    probs = F.softmax(logits, dim=-1)
    if print_entropy:
        # Use log_softmax rather than log(softmax): very peaked inputs
        # (tiny temperature) underflow to p == 0 and 0 * log(0) evaluates
        # to nan.  Entries with p == 0 contribute exactly 0 to the entropy,
        # so they are masked out explicitly.
        log_probs = F.log_softmax(logits, dim=-1)
        plogp = torch.where(
            probs > 0, probs * log_probs, torch.zeros_like(probs)
        )
        print(-plogp.sum())
    probs = probs.to(device)
    return torch.matmul(probs, embedding.weight)
def _greedy(logits):
_, last = torch.topk(logits, k=1, dim=-1)
return last
def one_hot(tensor, dimension):
    """One-hot encode integer ids along a new trailing axis.

    Inputs with fewer than two axes get leading singleton axes added first,
    so the result always has at least three axes.

    Args:
        tensor: int64 tensor of ids, each in ``[0, dimension)``.
        dimension: size of the new trailing (one-hot) axis.

    Returns:
        int64 tensor of shape ``tensor.shape + (dimension,)`` (after the
        leading axes were added), on the same device as ``tensor``.
    """
    while tensor.dim() < 2:
        tensor = tensor.unsqueeze(0)
    # F.one_hot builds the result directly on the input's device; the legacy
    # torch.LongTensor(...) constructor allocated uninitialised CPU memory
    # first, and the hard-coded scatter axis only worked for 2-D inputs.
    return F.one_hot(tensor, num_classes=dimension)
def get_text_from_logits(logits, tokenizer):
    """Greedily decode per-position logits into text.

    Args:
        logits: 2-D tensor with one row of scores per position.
        tokenizer: object whose ``decode`` accepts a list of token ids.

    Returns:
        A ``(text, nll, token_ids)`` tuple:
        text is the decoded string with newlines replaced by spaces;
        nll is the negated sum, over positions, of the log-softmax score of
        the chosen token, as a Python float;
        token_ids is a 1-D tensor holding the chosen id of every position.
        An empty ``logits`` yields empty ids and an nll of zero.
    """
    # Same selection rule as _greedy (topk with k=1), applied to every row
    # at once instead of one row (and one host sync) per position.
    chosen = logits.topk(1, dim=-1).indices
    log_probs = logits.log_softmax(-1)
    # Accumulate in float64, like the Python-float running sum it replaces.
    nll = -log_probs.gather(-1, chosen).double().sum().item()
    token_ids = chosen.squeeze(-1)
    text = tokenizer.decode(token_ids.tolist()).replace('\n', ' ')
    return text, nll, token_ids
# Load the pretrained GPT-2 tokenizer and LM-head model, put the model on the
# GPU and switch it to evaluation mode.
model_size = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_size)
model = GPT2LMHeadModel.from_pretrained(model_size, output_hidden_states=True)
model = model.to('cuda').eval()

# Encode the prompt and turn its token ids into one-hot vectors over the
# tokenizer's vocabulary; decode() reads `input_one_hot` as the prompt.
input_ids = tokenizer.encode("The dog", return_tensors="pt")
input_ids = input_ids.to('cuda')
input_one_hot = one_hot(input_ids, dimension=tokenizer.vocab_size)
def decode(model, length, temperature, device):
    """GPT2 decoding via dense representations (no arg-max).

    The prompt is the module-level ``input_one_hot``.  At every step the
    temperature-scaled logits of the last position are fed back through
    ``embed_inputs`` (softmax-weighted embedding mix) as the next input,
    rather than committing to a single token.

    Args:
        model: causal LM accepting ``inputs_embeds`` / ``past_key_values``
            and returning an object with ``logits`` and ``past_key_values``.
        length: number of decoding steps.
        temperature: divisor applied to the one-hot prompt and to every
            step's logits before the softmax.
        device: device the soft inputs are moved to before embedding.

    Returns:
        The per-step scaled logits concatenated along axis 1, or ``None``
        when ``length`` is 0.  Gradients are not disabled here.
    """
    embedding = model.get_input_embeddings()
    past = None
    inputs_embeds = None
    step_logits = []
    for _ in range(length):
        if past is None:
            # .float() keeps the prompt on its current device (the previous
            # .type(torch.FloatTensor) forced a CPU copy), and the caller's
            # `device` is honoured instead of a hard-coded 'cuda'.
            inputs_embeds = embed_inputs(
                embedding,
                input_one_hot.float() / temperature,
                device=device,
                print_entropy=True,
            )
        model_outputs = model(past_key_values=past, inputs_embeds=inputs_embeds)
        past = model_outputs.past_key_values
        # Keep only the last position, retaining the sequence axis.
        logits = model_outputs.logits[:, -1:, :] / temperature
        step_logits.append(logits)
        inputs_embeds = embed_inputs(embedding, logits, device=device)
    # Concatenate once at the end rather than re-copying on every step.
    return torch.cat(step_logits, dim=1) if step_logits else None
# Sweep the temperature: run 100 soft-decoding steps at each value and print
# the greedy read-out of the resulting logits.
for temperature in (0.001, 0.01, 0.1, 0.2, 0.3, 1, 5):
    print(f" ------- \n * temperature: {temperature}")
    logits_so_far = decode(model, length=100, temperature=temperature, device='cuda')
    # Only the first batch element is converted back to text.
    text, nll, _ = get_text_from_logits(logits_so_far[0], tokenizer)
    print(text)
# Expected output (the run for temperature 0.2 is not listed below):
# -------
# * temperature: 0.001
# tensor(nan)
# was found in a field near the intersection of West and West Streets. The dog was taken to the hospital with a broken collarbone and a broken leg. The dog was taken to the hospital with a broken collarbone and a broken leg. The dog was taken to the hospital with a broken collarbone and a broken leg. The dog was taken to the hospital with a broken collarbone and a broken leg. The dog was taken to the hospital with
# -------
# * temperature: 0.01
# tensor(3.8029e-37)
# was found in a field near the intersection of West and West Streets. The dog was taken to the hospital with a broken collarbone and a broken leg. The dog was taken to the hospital with a broken collarbone and a broken leg. The dog was taken to the hospital with a broken collarbone and a broken leg. The dog was taken to the hospital with a broken collarbone and a broken leg. The dog was taken to the hospital with
# -------
# * temperature: 0.1
# tensor(16.2817)
# the dog's head, and the dog's head is the dog's head. (2) If the dog's head is the dog's head, and the dog's head is the dog's head, the dog's head is the dog's head. (3) If the dog's head is the dog's head, and the dog's head is the dog's head, the dog's head is the dog's head. (4) If the dog's head
# -------
# * temperature: 0.3
# tensor(21.6467)
# G-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-
# -------
# * temperature: 1
# tensor(21.6488)
# G. J S I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I I
# -------
# * temperature: 5
# tensor(21.6482)
# G J ( ( ( ( ( ( ( ( ( :: :: :: :: The The The The The The The The The The The The The The January January January January January January January January January January January ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ― ―
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment