Training and evaluation code for a simple model that predicts a token removed from a sentence
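'''
Train a linear classification head on top of a frozen bert-base-uncased
encoder: a single randomly chosen subword is removed from each sentence,
and the head has to predict the removed token's vocabulary id from the
[CLS] embedding of the truncated sentence. Only the head's parameters are
updated; the encoder is run under torch.no_grad().
'''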
import json
from math import ceil
from random import shuffle
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from transformers import AdamW, get_scheduler
from tqdm.auto import tqdm
class ClassificationHead(nn.Module):
    '''
    A two-layer classification head (linear -> tanh -> linear) over token
    embeddings. Defined as an alternative, but not used in the main script below.
    '''

    def __init__(self, n_classes, input_size=768):
        super().__init__()
        self.linear1 = nn.Linear(input_size, 768)
        self.linear2 = nn.Linear(768, n_classes)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.tanh(x)
        return self.linear2(x)
class ClassificationHeadSimple(nn.Module):
    '''
    A simple linear classifier that converts token embeddings
    to class scores.
    '''

    def __init__(self, n_classes, input_size=768):
        super().__init__()
        self.linear = nn.Linear(input_size, n_classes)

    def forward(self, x):
        return self.linear(x)
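# A minimal usage sketch for the head (illustration only, not executed by the
# script; shapes assume bert-base-uncased, i.e. hidden size 768 and a
# vocabulary of 30522 entries):
#
#     head = ClassificationHeadSimple(n_classes=30522)
#     logits = head(torch.randn(4, 768))  # -> shape (4, 30522)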
def throw_away_token(sentence_batch, tokeniser):
    '''
    Remove one randomly chosen subword (never [CLS] at position 0) from each
    tokenised sentence and return the removed token ids together with the
    modified batch encoding. A padding id 0 is appended to every sequence so
    that the batch tensors keep their original shape.
    '''
    throwaway_token_ids = [None for _ in sentence_batch]
    subword_dict = tokeniser(sentence_batch, return_tensors='pt', padding=True, truncation=True)
    for i in range(len(sentence_batch)):
        # Number of real (non-padding) tokens in this sentence.
        n_tokens = subword_dict['attention_mask'][i].sum().item()
        throwaway_idx = torch.randint(low=1, high=n_tokens, size=(1, 1)).item()
        throwaway_token_ids[i] = subword_dict['input_ids'][i][throwaway_idx]
        zero_tensor = torch.tensor([0])
        subword_dict['input_ids'][i] = torch.cat((
            subword_dict['input_ids'][i][:throwaway_idx],
            subword_dict['input_ids'][i][throwaway_idx+1:],
            zero_tensor))
        subword_dict['token_type_ids'][i] = torch.cat((
            subword_dict['token_type_ids'][i][:throwaway_idx],
            subword_dict['token_type_ids'][i][throwaway_idx+1:],
            zero_tensor))
        subword_dict['attention_mask'][i] = torch.cat((
            subword_dict['attention_mask'][i][:throwaway_idx],
            subword_dict['attention_mask'][i][throwaway_idx+1:],
            zero_tensor))
    return throwaway_token_ids, subword_dict
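# Illustration (assuming the bert-base-uncased tokenizer loaded below): for
# ['the cat sat'] the tokeniser produces [CLS] the cat sat [SEP]; one position
# other than [CLS] is sampled, e.g. 'cat', whose id becomes the gold label,
# and the remaining ids are shifted left with a padding id 0 appended so the
# batch tensors keep their shape.
#
#     labels, batch = throw_away_token(['the cat sat'], tokenizer)
#     # labels[0] holds the id of the removed subword;
#     # batch['input_ids'] and batch['attention_mask'] keep their original shape.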
if __name__ == '__main__':
    model_name = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    bert_model = AutoModel.from_pretrained(model_name)
    bert_model.cuda()
    bert_model = nn.DataParallel(bert_model)

    # Only the classification head is trained; the BERT encoder stays frozen.
    n_classes = tokenizer.vocab_size
    classification_head = ClassificationHeadSimple(n_classes)
    classification_head.cuda()
    classification_head = nn.DataParallel(classification_head)
    optimizer = AdamW(list(classification_head.parameters()), lr=1e-5)

    with open('../data/hansard_short_sentences.json', 'r', encoding='utf-8') as inp:
        data_all = json.load(inp)
    indices_permuted = torch.randperm(len(data_all))
    data_dev = [data_all[i].lower() for i in indices_permuted[:1000]]
    data_test = [data_all[i].lower() for i in indices_permuted[1000:2000]]
    data_train = [data_all[i].lower() for i in indices_permuted[2000:22000]]

    n_epochs = 5
    n_training_steps = n_epochs * len(data_train)
    lr_scheduler = get_scheduler(
        'linear',
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=n_training_steps
    )
    loss_function = nn.CrossEntropyLoss()
    batch_size = 128
    min_dev_loss = float('inf')
    for epoch in range(n_epochs):
        # Train
        bert_model.train()
        epoch_train_losses = []
        n_steps_train = ceil(len(data_train) / batch_size)
        for batch_n in tqdm(range(n_steps_train), desc=f'Epoch {epoch+1}, train', leave=False):
            batch_sentences = data_train[batch_size * batch_n:
                                         batch_size * (batch_n + 1)]
            gold_labels, inputs = throw_away_token(batch_sentences, tokenizer)
            mbert_inputs = {
                k: v.cuda() for k, v in inputs.items()
            }
            # The encoder runs without gradients; only the head receives updates.
            with torch.no_grad():
                mbert_outputs = bert_model(**mbert_inputs).last_hidden_state
            # Predict the removed token from each sentence's [CLS] embedding.
            cls_embeddings = []
            for i in range(len(batch_sentences)):
                cls_embeddings.append(mbert_outputs[i, 0, :])
            logits = classification_head(torch.vstack(cls_embeddings))
            loss = loss_function(logits, torch.tensor(gold_labels).cuda())
            loss.backward()
            epoch_train_losses.append(loss.item())
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        print(
            f'Epoch {epoch+1} train loss: {torch.tensor(epoch_train_losses).mean()}')
        # Evaluate
        bert_model.eval()
        epoch_dev_losses = []
        n_steps_dev = ceil(len(data_dev) / batch_size)
        hits = 0
        misses = 0
        for batch_n in tqdm(range(n_steps_dev), desc=f'Epoch {epoch+1}, dev', leave=False):
            batch_sentences = data_dev[batch_size * batch_n:
                                       batch_size * (batch_n + 1)]
            gold_labels, inputs = throw_away_token(batch_sentences, tokenizer)
            mbert_inputs = {
                k: v.cuda() for k, v in inputs.items()
            }
            with torch.no_grad():
                mbert_outputs = bert_model(**mbert_inputs).last_hidden_state
            cls_embeddings = []
            for i in range(len(batch_sentences)):
                cls_embeddings.append(mbert_outputs[i, 0, :])
            logits = classification_head(torch.vstack(cls_embeddings))
            loss = loss_function(logits, torch.tensor(gold_labels).cuda())
            epoch_dev_losses.append(loss.item())
            # Count exact matches between predicted and removed token ids.
            classes = torch.argmax(logits, dim=1)
            for guessed, gold in zip(classes, gold_labels):
                if guessed == gold:
                    hits += 1
                else:
                    misses += 1
        print(
            f'Epoch {epoch+1} dev loss: {torch.tensor(epoch_dev_losses).mean()}')
        print(
            f'Epoch {epoch+1} dev accuracy: {round(hits / (hits + misses), 2)}')