Skip to content

Instantly share code, notes, and snippets.

@bh1995
Created December 30, 2020 17:35
Show Gist options
Save bh1995/ba54126f9218b6c330e4158632168c84 to your computer and use it in GitHub Desktop.
# Perform training loop for n epochs.
#
# Relies on names defined earlier in the script/notebook: `model`, `optimizer`,
# `data_loader_train`, `device`, `latest_model`, plus `tqdm`, `np`, `plt`, `torch`.
# The model is called as `model(images, targets)` in training mode and its result
# is used as a dict of loss tensors (NOTE(review): torchvision detection-model
# style, inferred from the call pattern — confirm against the model setup).
loss_list = []  # mean training loss of each epoch, in epoch order
n_epochs = 10
plot_every = 10  # redraw the per-iteration loss curve every this many iterations

# Convert parameters to float64 once, before training. Calling `model.double()`
# inside the batch loop repeated a full pass over every parameter on each
# iteration. `Module.double()` modifies the module in place and returns it, so
# the optimizer created earlier still refers to the same parameters.
model = model.double()
model.train()
for epoch in tqdm(range(n_epochs)):
    loss_epoch = []  # per-iteration losses for this epoch
    for iteration, (images, targets) in enumerate(tqdm(data_loader_train), start=1):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        loss_dict = model(images, targets)
        # Total loss is the sum of all individual loss terms returned by the model.
        losses = sum(loss_dict.values())
        losses.backward()
        optimizer.step()
        loss_epoch.append(losses.item())

        # Plot loss every `plot_every`-th iteration. Plotting on every iteration
        # is costly, and in a plain script `plt.show()` blocks until the window
        # is closed, stalling training after each batch.
        if iteration % plot_every == 0:
            plt.plot(range(1, iteration + 1), loss_epoch)
            plt.xlabel("Iteration")
            plt.ylabel("Loss")
            plt.show()

    # np.mean([]) emits a RuntimeWarning and returns nan; report nan explicitly
    # if the loader produced no batches.
    loss_epoch_mean = np.mean(loss_epoch) if loss_epoch else float("nan")
    loss_list.append(loss_epoch_mean)
    print("Average loss for epoch = {:.4f} ".format(loss_epoch_mean))

# Save model
model_nr = latest_model() + 1  # keep track of which model nr was just trained.
save_path = 'PUT YOUR SAVE PATH HERE' + str(model_nr)
torch.save(model.state_dict(), save_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment