Skip to content

Instantly share code, notes, and snippets.

@robgon-art
Last active October 25, 2020 15:00
Show Gist options
  • Save robgon-art/5280d8aee37a02c045cc2c3a59d6b54e to your computer and use it in GitHub Desktop.
Save robgon-art/5280d8aee37a02c045cc2c3a59d6b54e to your computer and use it in GitHub Desktop.
Train the ai8ball
from tqdm import tqdm
# Fine-tune `model` on `train_dataloader` with gradient accumulation.
# Relies on `model`, `train_dataloader`, `device`, `optimizer`, and `torch`
# being defined/imported elsewhere in the notebook/script.

# NOTE(review): batch_size is not used in this loop; presumably it configures
# train_dataloader elsewhere — confirm.
batch_size = 8
epochs = 3
# Number of mini-batches whose gradients are accumulated before each
# optimizer step.
grad_acc_steps = 4

# One entry per epoch: mean (unscaled) training loss over that epoch's batches.
train_loss_values = []
# NOTE(review): not filled in by this loop; presumably appended to by an
# evaluation step elsewhere — confirm.
dev_acc_values = []

for _ in tqdm(range(epochs), desc="Epoch"):
    # Training
    epoch_train_loss = 0
    model.train()
    model.zero_grad()
    num_batches = len(train_dataloader)
    for step, batch in enumerate(train_dataloader):
        input_ids = batch[0].to(device)
        attention_masks = batch[1].to(device)
        labels = batch[2].to(device)
        outputs = model(input_ids, token_type_ids=None,
                        attention_mask=attention_masks, labels=labels)
        # The model is called with labels=..., and element 0 of its output
        # is used as the training loss.
        loss = outputs[0]
        # Record the unscaled loss so the epoch average is the true mean
        # batch loss rather than that mean divided by grad_acc_steps.
        epoch_train_loss += loss.item()
        # Scale before backward so gradients summed over grad_acc_steps
        # batches correspond to their average.
        (loss / grad_acc_steps).backward()
        # Step every grad_acc_steps batches, and also on the final batch so
        # a trailing partial accumulation window is applied instead of being
        # discarded by the next epoch's zero_grad().
        if (step + 1) % grad_acc_steps == 0 or step + 1 == num_batches:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            model.zero_grad()
    # Guard against an empty dataloader (would otherwise divide by zero).
    epoch_train_loss = epoch_train_loss / max(num_batches, 1)
    train_loss_values.append(epoch_train_loss)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment