@vwxyzjn
Created February 1, 2024 22:16
import torch
import torch.nn.functional as F
import transformers

# GPT-2 has no pad token, so register one. The embedding matrix is not resized;
# that is fine here because pad positions are masked out and their ids are
# replaced with 0 before every forward pass below.
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", padding_side="right")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
pad_id = tokenizer.pad_token_id
policy = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
policy.generation_config.pad_token_id = policy.generation_config.eos_token_id
# Two left-padded queries of equal length.
query = torch.tensor([
    [pad_id, pad_id, 23073],
    [pad_id, pad_id, 234],
])
temperature = 0.7
context_length = query.shape[1]
def forward(model, query_responses, tokenizer):
    # Build an attention mask that hides the pad tokens.
    attention_mask = query_responses != tokenizer.pad_token_id
    # Position ids start at 0 on the first non-pad token of each row.
    position_ids = attention_mask.cumsum(1) - attention_mask.long()
    # Replace pad ids with 0 so we never index an out-of-vocabulary embedding.
    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    return model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
        output_hidden_states=True,
    )
def generate_and_return_logits(lm_backbone, queries, tokenizer, generation_config):
    """Generate in a way that does not let padding tokens affect the output."""
    context_length = queries.shape[1]
    attention_mask = queries != tokenizer.pad_token_id
    input_ids = torch.masked_fill(queries, ~attention_mask, 0)
    output = lm_backbone.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        # position_ids=attention_mask.cumsum(1) - attention_mask.long(),  # already handled in generation
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True,
    )
    # `output.scores` holds one tensor per generated token: the processed
    # (temperature-scaled) logits that the token was sampled from.
    logits = torch.stack(output.scores, 1)
    return torch.cat((queries, output.sequences[:, context_length:]), dim=1), logits
generation_config = transformers.GenerationConfig(
    max_new_tokens=5,
    min_new_tokens=5,
    temperature=temperature,
    top_k=0,      # disable top-k filtering
    top_p=1.0,    # disable nucleus sampling
    do_sample=True,
)
query_response, logits = generate_and_return_logits(policy, query, tokenizer, generation_config)
response = query_response[:, context_length:]
# Per-token log-probabilities taken from the sampling distribution used during generation.
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
print(f"{response=}")
print(f"{logprob=}")

# Recompute the same log-probabilities with a separate forward pass. The logits at
# position t predict token t + 1, hence the shift by one; dividing by the temperature
# mirrors the processing that `generate` applied before sampling.
output = forward(policy, query_response, tokenizer)
logits = output.logits[:, context_length - 1 : -1]
logits /= temperature
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
print(f"{logprob=}")