Skip to content

Instantly share code, notes, and snippets.

@itsuncheng
Last active June 12, 2020 10:32
Show Gist options
  • Save itsuncheng/07f761498dfdeb6d781bb6e793cf02b6 to your computer and use it in GitHub Desktop.
Save itsuncheng/07f761498dfdeb6d781bb6e793cf02b6 to your computer and use it in GitHub Desktop.
def _batch_loss(model, labels, titletext):
    """Run one forward pass on a batch and return the loss tensor the model reports.

    ``labels`` and ``titletext`` are cast to int64 and moved to the module-level
    ``device`` with a single ``Tensor.to`` call. This avoids the detour through a
    CPU ``torch.LongTensor`` that ``.type(torch.LongTensor).to(device)`` takes.
    The model is called as ``model(titletext, labels)`` and must return a pair
    whose first element is the loss.
    """
    labels = labels.to(device=device, dtype=torch.long)
    titletext = titletext.to(device=device, dtype=torch.long)
    loss, _ = model(titletext, labels)
    return loss


# Training Function
def train(model,
          optimizer,
          criterion=nn.BCELoss(),
          train_loader=train_iter,
          valid_loader=valid_iter,
          num_epochs=5,
          eval_every=len(train_iter) // 2,
          file_path=destination_folder,
          best_valid_loss=float("Inf")):
    """Train ``model`` with ``optimizer`` and checkpoint whenever validation loss improves.

    Each batch from ``train_loader`` / ``valid_loader`` is unpacked as
    ``((labels, title, text, titletext), _)``. Only ``labels`` and
    ``titletext`` are used. The loss comes from the model itself (see
    ``_batch_loss``), so ``criterion`` is accepted but never used.

    Every ``eval_every`` optimizer steps (counted across epochs), the model is
    run over all of ``valid_loader`` under ``torch.no_grad()``. The mean
    training loss for that interval and the mean per-batch validation loss are
    recorded and printed. If the validation loss is lower than
    ``best_valid_loss``:

    - the model is saved to ``<file_path>/model.pt`` via ``save_checkpoint``;
    - the loss history is saved to ``<file_path>/metrics.pt`` via
      ``save_metrics``.

    The loss history is saved once more when training finishes.

    Args:
        model: module to train; called as ``model(titletext, labels)`` and
            expected to return ``(loss, ...)``.
        optimizer: optimizer that updates ``model``'s parameters.
        criterion: unused; kept so existing call sites keep working.
        train_loader: iterable of training batches that supports ``len()``.
        valid_loader: iterable of validation batches that supports ``len()``.
        num_epochs: number of passes over ``train_loader``.
        eval_every: evaluation interval in optimizer steps. The default is
            computed once, when this function is defined. Values below 1 are
            clamped to 1; a single-batch ``train_iter`` gives a default of 0,
            which would otherwise cause a modulo-by-zero.
        file_path: output directory (str or path-like) for ``model.pt`` and
            ``metrics.pt``.
        best_valid_loss: validation loss a checkpoint has to beat. With the
            default of infinity, the first evaluation always saves.

    Raises:
        ValueError: if an evaluation is reached and ``valid_loader`` has no
            batches, because the mean validation loss would divide by zero.
    """
    import os  # local import: the file's top-level import block is not part of this function

    # Guard the `global_step % eval_every` test and the per-interval average.
    eval_every = max(1, eval_every)

    checkpoint_path = os.path.join(file_path, 'model.pt')
    metrics_path = os.path.join(file_path, 'metrics.pt')

    # initialize running values
    running_loss = 0.0
    global_step = 0
    train_loss_list = []
    valid_loss_list = []
    global_steps_list = []

    # training loop
    model.train()
    for epoch in range(num_epochs):
        for (labels, _title, _text, titletext), _ in train_loader:
            loss = _batch_loss(model, labels, titletext)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update running values
            running_loss += loss.item()
            global_step += 1

            if global_step % eval_every != 0:
                continue

            # evaluation step: one full pass over the validation set
            num_valid_batches = len(valid_loader)
            if num_valid_batches == 0:
                raise ValueError('valid_loader has no batches; cannot compute validation loss')

            model.eval()
            valid_running_loss = 0.0
            with torch.no_grad():
                for (v_labels, _v_title, _v_text, v_titletext), _ in valid_loader:
                    valid_running_loss += _batch_loss(model, v_labels, v_titletext).item()
            # Switch back to training mode before the next optimizer step.
            model.train()

            # Running loss is reset after every evaluation, so it always covers
            # exactly `eval_every` steps.
            average_train_loss = running_loss / eval_every
            average_valid_loss = valid_running_loss / num_valid_batches
            train_loss_list.append(average_train_loss)
            valid_loss_list.append(average_valid_loss)
            global_steps_list.append(global_step)
            running_loss = 0.0

            # print progress
            print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f}, Valid Loss: {:.4f}'
                  .format(epoch+1, num_epochs, global_step, num_epochs*len(train_loader),
                          average_train_loss, average_valid_loss))

            # checkpoint on improvement
            if best_valid_loss > average_valid_loss:
                best_valid_loss = average_valid_loss
                save_checkpoint(checkpoint_path, model, best_valid_loss)
                save_metrics(metrics_path, train_loss_list, valid_loss_list, global_steps_list)

    save_metrics(metrics_path, train_loss_list, valid_loss_list, global_steps_list)
    print('Finished Training!')
# Build the model and place it on `device`. Optimise all of its parameters with
# Adam at a learning rate of 2e-5, then train using train()'s default loaders,
# epoch count and output folder.
model = BERT()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=2e-5)
train(model=model, optimizer=optimizer)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment