@akshay-3apr
Last active September 6, 2022 11:21
import numpy as np
import torch
from sklearn.metrics import accuracy_score, matthews_corrcoef
from tqdm import tnrange

# Assumes model, optimizer, scheduler, train_dataloader, validation_dataloader,
# device and epochs are already defined in the earlier cells of the notebook.

## Store our loss and learning rate for plotting
train_loss_set = []
learning_rate = []

# Gradients get accumulated by default
model.zero_grad()

# tnrange is a tqdm wrapper around the normal python range
for epoch in tnrange(1, epochs + 1, desc='Epoch'):
    print("<" + "=" * 22 + F" Epoch {epoch} " + "=" * 22 + ">")
    # Accumulate the total loss for this epoch
    batch_loss = 0

    for step, batch in enumerate(train_dataloader):
        # Set our model to training mode (as opposed to evaluation mode)
        model.train()

        # Add batch to GPU
        batch = tuple(t.to(device) for t in batch)
        # Unpack the inputs from our dataloader
        b_input_ids, b_input_mask, b_labels = batch

        # Forward pass
        outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
        loss = outputs[0]

        # Backward pass
        loss.backward()

        # Clip the norm of the gradients to 1.0
        # (gradient clipping is no longer built into AdamW)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Update parameters using the computed gradients
        optimizer.step()

        # Update the learning rate schedule
        scheduler.step()

        # Clear the previously accumulated gradients
        optimizer.zero_grad()

        # Update tracking variables
        batch_loss += loss.item()

    # Calculate the average loss over the training data
    avg_train_loss = batch_loss / len(train_dataloader)

    # Store the current learning rate
    for param_group in optimizer.param_groups:
        print("\n\tCurrent Learning rate: ", param_group['lr'])
        learning_rate.append(param_group['lr'])

    train_loss_set.append(avg_train_loss)
    print(F'\n\tAverage Training loss: {avg_train_loss}')

    # Validation

    # Put model in evaluation mode to evaluate loss on the validation set
    model.eval()

    # Tracking variables
    eval_accuracy, eval_mcc_accuracy, nb_eval_steps = 0, 0, 0

    # Evaluate data for one epoch
    for batch in validation_dataloader:
        # Add batch to GPU
        batch = tuple(t.to(device) for t in batch)
        # Unpack the inputs from our dataloader
        b_input_ids, b_input_mask, b_labels = batch

        # Tell the model not to compute or store gradients, saving memory and speeding up validation
        with torch.no_grad():
            # Forward pass, calculate logit predictions
            logits = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)

        # Move logits and labels to CPU
        logits = logits[0].to('cpu').numpy()
        label_ids = b_labels.to('cpu').numpy()

        pred_flat = np.argmax(logits, axis=1).flatten()
        labels_flat = label_ids.flatten()

        tmp_eval_accuracy = accuracy_score(labels_flat, pred_flat)
        tmp_eval_mcc_accuracy = matthews_corrcoef(labels_flat, pred_flat)

        eval_accuracy += tmp_eval_accuracy
        eval_mcc_accuracy += tmp_eval_mcc_accuracy
        nb_eval_steps += 1

    print(F'\n\tValidation Accuracy: {eval_accuracy / nb_eval_steps}')
    print(F'\n\tValidation MCC Accuracy: {eval_mcc_accuracy / nb_eval_steps}')
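Since the loop collects train_loss_set and learning_rate for plotting, a minimal plotting sketch could look like the following (assuming matplotlib is installed; the variable names are the ones defined above):

import matplotlib.pyplot as plt

# Plot the per-epoch average training loss and the learning rate side by side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(train_loss_set, marker='o')
ax1.set_title('Average training loss per epoch')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax2.plot(learning_rate, marker='o')
ax2.set_title('Learning rate per epoch')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Learning rate')
plt.tight_layout()
plt.show()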
@venkatasg

Thanks for your Medium post and your code! It's my first time fine-tuning BERT, and writing my code by following your instructions really helped me understand :)

I found a trick for speeding up training in [this pull request](https://github.com/huggingface/transformers/pull/66) ("speedup by truncating unused part" by artemisart on huggingface/transformers): it truncates the unused padding in each batch down to the length of the longest sequence in that batch. You might consider replacing lines 18-22 with these:

# Unpack the inputs from our dataloader
b_input_ids, b_input_mask, b_labels = batch

# Truncate the batch to the longest sequence in it for a speedup
# (+1 because nonzero() gives the index of the last column that still holds a real token)
max_length = (b_input_mask != 0).max(0)[0].nonzero()[-1].item() + 1

if max_length < b_input_ids.shape[1]:
    b_input_ids = b_input_ids[:, :max_length].to(device)
    b_input_mask = b_input_mask[:, :max_length].to(device)
else:
    b_input_ids = b_input_ids.to(device)
    b_input_mask = b_input_mask.to(device)

b_labels = b_labels.to(device)
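
As a quick sanity check of the truncation logic, here is a small self-contained sketch on a toy attention mask (the toy_mask tensor is purely illustrative and not part of the original suggestion):

import torch

# Two padded sequences of real lengths 2 and 3 in a batch padded to 5 tokens
toy_mask = torch.tensor([[1, 1, 0, 0, 0],
                         [1, 1, 1, 0, 0]])
# Index of the last column containing a real token, plus one, gives the length to keep
max_length = (toy_mask != 0).max(0)[0].nonzero()[-1].item() + 1
print(max_length)  # prints 3, so slicing [:, :3] keeps every real token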
