@tuetschek
Created January 23, 2024 23:05
GPT2DoubleHeadsModel used for actual classification, not choice selection
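In a nutshell: the stock GPT2DoubleHeadsModel sets config.num_labels = 1, so its classification head emits one score per candidate sequence and picks among choices; the subclass below sets num_labels = 2 instead, so mc_logits become per-class scores and mc_labels can carry ordinary class indices (sentiment 0/1 here) while the LM head is trained as usual.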
import torch
import transformers
import tqdm
import copy
import numpy as np
from logzero import logger

# some tiny data -- sentiment classification + LM
DATA = [
    [{'text': 'This is good . [CLS]', 'class': 1},
     {'text': 'This is bad . [CLS]', 'class': 0}],
    [{'text': 'I liked it . [CLS]', 'class': 1},
     {'text': 'I hated it . [CLS]', 'class': 0}],
    [{'text': 'It was great . [CLS]', 'class': 1},
     {'text': 'It was bad . [CLS]', 'class': 0}],
]


class GPT2DoubleHeadsSC(transformers.GPT2DoubleHeadsModel):

    def __init__(self, config):
        transformers.GPT2PreTrainedModel.__init__(self, config)
        config.num_labels = 2  # XXX This is the only thing changed w.r.t. GPT2DoubleHeadsModel
        self.transformer = transformers.GPT2Model(config)
        self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = transformers.modeling_utils.SequenceSummary(config)
        # Model parallel
        self.model_parallel = False
        self.device_map = None
        # Initialize weights and apply final processing
        self.post_init()


class DataLoader:
    """Pre-tokenizes the toy data into batches for the model
    (assumes all texts within a batch tokenize to the same length)."""

    def __init__(self, data, tokenizer):
        self.data = []
        self.tokenizer = tokenizer
        for batch in data:
            tokenized = [self.tokenizer(i['text']) for i in batch]
            self.data.append({'input_ids': torch.tensor([i['input_ids'] for i in tokenized]),
                              'labels': torch.tensor([i['input_ids'] for i in tokenized]),
                              'attention_mask': torch.tensor([i['attention_mask'] for i in tokenized]),
                              'mc_token_ids': torch.tensor([i['input_ids'].index(tokenizer.cls_token_id)
                                                            for i in tokenized]),
                              'mc_labels': torch.tensor([i['class'] for i in batch])})

    def __iter__(self):
        for batch in self.data:
            yield copy.copy(batch)

    def __len__(self):
        return len(self.data)


class Trainer:

    def __init__(self,
                 model,
                 train_data_loader,
                 epochs: int,
                 optimizer,
                 scheduler,
                 logger=logger):
        self.model = model
        self.device = model.device
        self.train_data_loader = train_data_loader
        self.epochs = epochs
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.logger = logger

    def train(self):
        self.logger.info('Starting training...')
        for epoch in range(self.epochs):
            self.logger.info(f'====== Epoch {epoch + 1}/{self.epochs} Training ======')
            self.model.train()
            ep_loss = 0
            for step, batch in enumerate(tqdm.tqdm(self.train_data_loader)):
                output = self.model(**batch)
                # Backpropagate the sum of the LM loss and the classification loss
                loss = output.loss + output.mc_loss
                ep_loss += loss.item()
                loss.backward()
                # Optimizer and scheduler steps
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
            self.logger.debug(f'Epoch loss: {ep_loss}')


def test_training():
    transformers.set_seed(42)
    tokenizer = transformers.GPT2Tokenizer.from_pretrained("distilgpt2")
    tokenizer.add_special_tokens({"cls_token": "[CLS]"})
    model = GPT2DoubleHeadsSC.from_pretrained('distilgpt2')
    model.resize_token_embeddings(len(tokenizer))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = transformers.get_constant_schedule(optimizer)
    loader = DataLoader(DATA, tokenizer)
    trainer = Trainer(model, loader, 20, optimizer, scheduler)
    # overfit the model on this data
    trainer.train()

    # testing that the model is really overfit
    model.eval()
    toks_total, toks_corr, cls_total, cls_corr = 0, 0, 0, 0
    for batch in loader:
        with torch.no_grad():
            # no mc_token_ids passed -- the summary head then uses the last token,
            # which is [CLS] in this data
            output = model(**{'attention_mask': batch['attention_mask'], 'input_ids': batch['input_ids']})
        # next-token predictions vs. gold tokens (shifted by one position)
        toks_preds = batch['input_ids'].numpy()[:, 1:] == torch.argmax(output.logits, dim=-1).numpy()[:, :-1]
        cls_preds = batch['mc_labels'].numpy() == torch.argmax(output.mc_logits, dim=-1).numpy()
        toks_corr += np.sum(toks_preds)
        toks_total += np.prod(toks_preds.shape)
        cls_corr += np.sum(cls_preds)
        cls_total += np.prod(cls_preds.shape)
    logger.info(f'Token accuracy: {toks_corr / toks_total}')
    logger.info(f'Classification accuracy: {cls_corr / cls_total}')
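

# A minimal inference sketch (not used above; the helper name `classify_sentence`
# is illustrative only): assuming a model and tokenizer set up as in test_training(),
# a single sentence is classified the same way the DataLoader prepares batches.
def classify_sentence(model, tokenizer, text):
    model.eval()
    tokenized = tokenizer(text + ' [CLS]')
    input_ids = torch.tensor([tokenized['input_ids']])
    attention_mask = torch.tensor([tokenized['attention_mask']])
    # position of the [CLS] token, fed to the classification head
    mc_token_ids = torch.tensor([tokenized['input_ids'].index(tokenizer.cls_token_id)])
    with torch.no_grad():
        output = model(input_ids=input_ids, attention_mask=attention_mask,
                       mc_token_ids=mc_token_ids)
    return int(torch.argmax(output.mc_logits, dim=-1).item())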


if __name__ == '__main__':
    test_training()