Fine-tuning a GPT-2 model to handle a character list
from random import choice

import torch
from torch.nn import functional as F
from torch.optim import AdamW  # transformers.AdamW is deprecated; torch's AdamW behaves the same here
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")
if __name__ == "__main__":
    # Standard transformer fine-tuning recipe: biases and LayerNorm weights
    # are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
    # The character list is prepended to every training context, so the model
    # learns the passage conditioned on the cast of characters.
    background = "James.\nJohn.\nJacob"
    full_text = "John turned to James, \"Hey James, what do you think about Jacob?\" Well, we'll just have to wait and see. Jacob just might make it."
    background_encoding_ids = tokenizer(background, return_tensors='pt')['input_ids'].to("cuda")
    text_encoding = tokenizer(full_text, return_tensors='pt')['input_ids'].to("cuda")
    # Valid target positions: predict the token at `index` from the tokens before it.
    document_length = range(2, text_encoding.shape[-1] - 1)
    for i in range(1000):
        optimizer.zero_grad()
        index = choice(document_length)
        true_index = text_encoding[:, index]   # target token id, shape [1]
        input_ids = text_encoding[:, :index]   # context: every token before the target (GPT-2 adds no BOS token)
        payload = torch.cat((background_encoding_ids, input_ids), dim=1)
        next_token_logit = model(payload).logits[:, -1, :]
        probs = F.softmax(next_token_logit, dim=-1)
        _, predicted_token_index = torch.max(probs, dim=1)
        # F.cross_entropy expects raw logits and applies log-softmax internally;
        # feeding it softmax probabilities would flatten the gradient signal.
        loss = F.cross_entropy(next_token_logit, true_index)
        real_word = tokenizer.decode(true_index.detach().tolist())
        predicted_word = tokenizer.decode(predicted_token_index.detach().tolist())
        print(f"loss: {loss.cpu().detach()}\treal word: {real_word}\tpredicted word: {predicted_word}\n")
        loss.backward()
        optimizer.step()
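
    # A hedged alternative to the per-token loop above, not part of the original
    # gist: Hugging Face causal-LM models compute the shifted next-token loss
    # over every position in a single forward pass when `labels` is supplied
    # (label positions set to -100 are ignored, which could be used to mask out
    # the character-list prefix). Reuses the model/optimizer/tensors from above.
    full_ids = torch.cat((background_encoding_ids, text_encoding), dim=1)
    full_loss = model(full_ids, labels=full_ids).loss
    full_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"full-sequence loss: {full_loss.cpu().detach()}")

    # Persisting the fine-tuned weights (standard transformers API); the
    # directory name is just an example.
    model.save_pretrained("gpt2-characters")
    tokenizer.save_pretrained("gpt2-characters")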
# Interactive sampling loop (alternative entry point, kept commented out).
# Requires: from transformers import top_k_top_p_filtering
# if __name__ == "__main__":
#     print("ready for processing...\n")
#     while True:
#         data = input()
#         input_ids = tokenizer.encode(data, return_tensors='pt').to('cuda')
#         for _ in range(50):
#             next_token_logits = model(input_ids).logits[:, -1, :]
#             filtered_next_token_logits = top_k_top_p_filtering(next_token_logits.clone(), top_k=500, top_p=0.95)
#             probs = F.softmax(filtered_next_token_logits, dim=-1)
#             next_token = torch.multinomial(probs, num_samples=1)
#             input_ids = torch.cat([input_ids, next_token], dim=-1)
#         result = tokenizer.decode(input_ids.tolist()[0])
#         print(result)
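
# A hedged sketch of the same sampling done with the library's built-in
# generation utility instead of the manual top-k/top-p loop; the settings
# mirror the commented-out block above, and the prompt string is just an example.
prompt_ids = tokenizer("John turned to James,", return_tensors='pt')['input_ids'].to("cuda")
sample_ids = model.generate(
    prompt_ids,
    do_sample=True,                        # sample rather than greedy-decode
    top_k=500,
    top_p=0.95,
    max_length=prompt_ids.shape[-1] + 50,  # up to 50 new tokens
    pad_token_id=tokenizer.eos_token_id,   # GPT-2 has no pad token; reuse EOS
)
print(tokenizer.decode(sample_ids[0].tolist()))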