@zaemyung
Created October 17, 2023 08:20
custom trainer example
import math
import time

import torch
from sklearn.metrics import confusion_matrix  # kept from the original snippet (not used below)
from torch.optim import AdamW
from tqdm import tqdm
from transformers import Trainer
class CustomTrainer(Trainer):
    def _inner_training_loop(
            self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None,
            ignore_keys_for_eval=None):
        number_of_epochs = args.num_train_epochs
        device = args.device

        # Per-epoch metrics and timing statistics returned at the end.
        train_loss = []
        train_acc = []
        eval_loss = []
        eval_acc = []
        times_per_epoch = []
        times_per_inference = []

        criterion = torch.nn.CrossEntropyLoss().to(device)
        self.optimizer = AdamW(self.model.parameters(), lr=args.learning_rate)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 1, gamma=0.9)

        train_dataloader = self.get_train_dataloader()
        eval_dataloader = self.get_eval_dataloader()
        max_steps = math.ceil(len(train_dataloader) * args.num_train_epochs)
        for epoch in range(number_of_epochs):
            self.model.train()
            self.model.zero_grad()
            train_loss_per_epoch = 0
            train_acc_per_epoch = 0
            n_train_samples = 0
            with tqdm(train_dataloader, unit="batch") as training_epoch:
                training_epoch.set_description(f"Training Epoch {epoch}")
                starttime_epoch = time.time()
                for step, inputs in enumerate(training_epoch):
                    inputs = inputs.to(device)
                    labels = inputs['labels']
                    self.optimizer.zero_grad()

                    # Time the forward pass only.
                    start_inference = time.time()
                    output = self.model(**inputs)
                    end_inference = time.time()
                    times_per_inference.append(end_inference - start_inference)

                    loss = criterion(output.logits, labels)
                    train_loss_per_epoch += loss.item()
                    loss.backward()
                    self.optimizer.step()

                    train_acc_per_epoch += (output.logits.argmax(1) == labels).sum().item()
                    n_train_samples += labels.size(0)
                endtime_epoch = time.time()
                times_per_epoch.append(endtime_epoch - starttime_epoch)

            self.scheduler.step()
            train_loss_per_epoch /= len(train_dataloader)
            train_acc_per_epoch /= n_train_samples
            # Evaluation: run the model without gradient tracking (no backward pass here).
            self.model.eval()
            eval_loss_per_epoch = 0
            eval_acc_per_epoch = 0
            n_eval_samples = 0
            with torch.no_grad(), tqdm(eval_dataloader, unit="batch") as eval_epoch:
                eval_epoch.set_description(f"Evaluation Epoch {epoch}")
                for step, inputs in enumerate(eval_epoch):
                    inputs = inputs.to(device)
                    labels = inputs['labels']
                    output = self.model(**inputs)
                    loss = criterion(output.logits, labels)
                    eval_loss_per_epoch += loss.item()
                    eval_acc_per_epoch += (output.logits.argmax(1) == labels).sum().item()
                    n_eval_samples += labels.size(0)
            eval_loss_per_epoch /= len(eval_dataloader)
            eval_acc_per_epoch /= n_eval_samples
            print(f'\tTrain Loss: {train_loss_per_epoch:.4f} | Train Acc: {train_acc_per_epoch * 100.0:.2f}%')
            print(f'\tEval Loss: {eval_loss_per_epoch:.4f} | Eval Acc: {eval_acc_per_epoch * 100.0:.2f}%')

            train_loss.append(train_loss_per_epoch)
            train_acc.append(train_acc_per_epoch)
            eval_loss.append(eval_loss_per_epoch)
            eval_acc.append(eval_acc_per_epoch)

            # Save a checkpoint of the underlying model after every epoch.
            self.model.save_pretrained(f'./model_epoch_{epoch}')

        return train_loss, train_acc, eval_loss, eval_acc, times_per_epoch, times_per_inference
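# --- Not part of the original gist: a minimal, hypothetical setup for the names the
# usage below relies on (model, training_args, train_set, eval_set, data_collator).
# The checkpoint ('bert-base-uncased'), dataset ('imdb'), and hyperparameters are
# assumptions; swap in your own task and values.
from datasets import load_dataset
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          DataCollatorWithPadding, TrainingArguments)

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

raw = load_dataset('imdb')
tokenized = raw.map(lambda ex: tokenizer(ex['text'], truncation=True), batched=True)
tokenized = tokenized.remove_columns(['text']).rename_column('label', 'labels')
train_set = tokenized['train']
eval_set = tokenized['test']

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
training_args = TrainingArguments(output_dir='./output',
                                  num_train_epochs=3,
                                  learning_rate=2e-5,
                                  per_device_train_batch_size=16)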
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=eval_set,
    data_collator=data_collator)
train_loss, train_acc, eval_loss, eval_acc, times_per_epoch, times_per_inference = trainer.train()
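# Because _inner_training_loop is overridden to return the metric lists, trainer.train()
# here yields those six lists rather than the usual TrainOutput object.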