@zeryx
Created December 9, 2020 16:54
Fine-tuning a GPT-2 model to handle a character list
# Fine-tune GPT-2 to predict the next token of a short story, conditioned on a
# "background" list of character names prepended to every training example.
from random import choice

import torch
from torch.nn import functional as F
from torch.optim import AdamW  # transformers.AdamW is deprecated; torch's optimizer works the same here
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")
if __name__ == "__main__":
    # Apply weight decay to everything except biases and LayerNorm weights,
    # following the standard transformer fine-tuning recipe.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)

    background = "James.\nJohn.\nJacob"
    full_text = "John turned to James, \"Hey James, what do you think about Jacob?\"Well, we'll just have to wait and see. Jacob just might make it."
    background_encoding_ids = tokenizer(background, return_tensors='pt')['input_ids'].to("cuda")
    text_encoding = tokenizer(full_text, return_tensors='pt')['input_ids'].to("cuda")
    # Candidate split points: positions with at least two tokens of context
    # before them and at least one token after them.
    candidate_indices = range(2, text_encoding.shape[-1] - 1)

    for i in range(1000):
        optimizer.zero_grad()
        # Pick a random split point; everything before it is context, and the
        # token at the split is the training target.
        index = choice(candidate_indices)
        true_index = text_encoding[:, index]
        # GPT-2's tokenizer adds no BOS token, so keep the full prefix (the
        # original sliced from position 1, silently dropping the first token).
        input_ids = text_encoding[:, :index]
        # Prepend the character list so the model learns to condition on it.
        payload = torch.cat((background_encoding_ids, input_ids), dim=1)
        next_token_logits = model(payload).logits[:, -1, :]
        probs = F.softmax(next_token_logits, dim=-1)
        _, predicted_token_index = torch.max(probs, dim=1)
        # cross_entropy expects raw logits, not probabilities; passing the
        # softmaxed probs (as the original did) applies softmax twice and
        # flattens the gradient.
        loss = F.cross_entropy(next_token_logits, true_index)
        real_word = tokenizer.decode(true_index.detach().tolist()[0])
        predicted_word = tokenizer.decode(predicted_token_index.detach().tolist()[0])
        print(f"loss: {loss.cpu().detach()}\treal word: {real_word}\tpredicted word: {predicted_word}\n")
        loss.backward()
        optimizer.step()
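
    # Not in the original gist: a minimal sketch for persisting the fine-tuned
    # weights once the loop finishes. save_pretrained is standard Hugging Face
    # API; the output directory name is illustrative.
    model.save_pretrained("gpt2-finetuned-character-list")
    tokenizer.save_pretrained("gpt2-finetuned-character-list")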
# Interactive sampling loop (disabled). To enable it, comment out the training
# block above; note that top_k_top_p_filtering must also be imported from
# transformers (it shipped in the 3.x/early 4.x releases this gist targets and
# was removed later).
# if __name__ == "__main__":
#     print("ready for processing...\n")
#     while True:
#         data = input()
#         input_ids = tokenizer.encode(data, return_tensors='pt').to('cuda')
#         for _ in range(50):
#             next_token_logits = model(input_ids).logits[:, -1, :]
#             filtered_next_token_logits = top_k_top_p_filtering(next_token_logits.clone(), top_k=500, top_p=0.95)
#             probs = F.softmax(filtered_next_token_logits, dim=-1)
#             next_token = torch.multinomial(probs, num_samples=1)
#             input_ids = torch.cat([input_ids, next_token], dim=-1)
#         result = tokenizer.decode(input_ids.tolist()[0])
#         print(result)
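
# Not in the original gist: an alternative sketch using the stable generate()
# API instead of the manual top-k/top-p loop above. do_sample, top_k, top_p,
# and max_new_tokens are standard generate() arguments; the sampling values
# and the prompt are illustrative. Left commented out, like the loop above.
# if __name__ == "__main__":
#     prompt = "James.\nJohn.\nJacob\nJohn turned to James,"
#     encoded = tokenizer(prompt, return_tensors='pt').to('cuda')
#     sampled = model.generate(**encoded, do_sample=True, top_k=500,
#                              top_p=0.95, max_new_tokens=50,
#                              pad_token_id=tokenizer.eos_token_id)
#     print(tokenizer.decode(sampled[0]))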