Last active: October 25, 2020 15:00
-
-
Save robgon-art/5280d8aee37a02c045cc2c3a59d6b54e to your computer and use it in GitHub Desktop.
Train the ai8ball
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from tqdm import tqdm
# Fine-tune the model with gradient accumulation and record the mean
# training loss of every epoch in ``train_loss_values``.
#
# Relies on names defined earlier in the script/notebook (not visible here):
#   model            - called with input ids plus ``token_type_ids``,
#                      ``attention_mask`` and ``labels``; its first output
#                      is used as the loss.
#   train_dataloader - yields batches indexable as
#                      [input_ids, attention_masks, labels].
#   optimizer        - stepped once per accumulated group of batches.
#   device           - target device that every batch tensor is moved to.
batch_size = 8  # NOTE(review): not used below; presumably used to build the dataloaders — confirm
epochs = 3
grad_acc_steps = 4  # mini-batches whose gradients are accumulated per optimizer step
train_loss_values = []  # mean per-batch training loss, one entry per epoch
dev_acc_values = []  # NOTE(review): never filled here; presumably by a validation step elsewhere

for _ in tqdm(range(epochs), desc="Epoch"):
    # Training
    epoch_train_loss = 0
    num_batches = len(train_dataloader)
    model.train()
    model.zero_grad()
    for step, batch in enumerate(train_dataloader):
        input_ids = batch[0].to(device)
        attention_masks = batch[1].to(device)
        labels = batch[2].to(device)
        outputs = model(input_ids, token_type_ids=None,
                        attention_mask=attention_masks, labels=labels)
        loss = outputs[0]
        # Record the *unscaled* loss so the epoch average below is the real
        # mean batch loss, not one shrunk by a factor of grad_acc_steps.
        epoch_train_loss += loss.item()
        # Scale only for backward(): the gradients summed over
        # grad_acc_steps batches then behave like an average.
        loss = loss / grad_acc_steps
        loss.backward()
        # Update every grad_acc_steps batches, and also on the last batch so a
        # trailing partial group is applied instead of being left unused and
        # then cleared by the next epoch's model.zero_grad().
        if (step + 1) % grad_acc_steps == 0 or step + 1 == num_batches:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            model.zero_grad()
    epoch_train_loss = epoch_train_loss / num_batches
    train_loss_values.append(epoch_train_loss)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.