Skip to content

Instantly share code, notes, and snippets.

@gauravbansal98
Created May 14, 2020 11:45
Show Gist options
  • Save gauravbansal98/e611e8e1eb198af2397cc9d7568b1889 to your computer and use it in GitHub Desktop.
Save gauravbansal98/e611e8e1eb198af2397cc9d7568b1889 to your computer and use it in GitHub Desktop.
# Training script for CaptionModel: trains with Adam + cross-entropy on
# batches of (image, caption, target_word) produced by data_generator,
# periodically writes checkpoints to ./results, and reports the average
# train/test loss after every epoch.
#
# NOTE(review): relies on names defined elsewhere in the file
# (CaptionModel, vocab_size, args, data_generator, tokenizer, max_length,
# train_/test_ descriptions and features); their shapes/types are not
# visible here.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CaptionModel(vocab_size).to(device)
# Default reduction='mean': each loss.item() is the mean over one batch,
# so it is re-weighted by batch size before being accumulated below.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

if args.checkpoint is not None:
    print("Loading the checkpoint")
    # map_location lets a checkpoint saved on one device load on another
    # (e.g. a GPU-trained checkpoint on a CPU-only machine).
    model.load_state_dict(torch.load(args.checkpoint, map_location=device))

# torch.save does not create missing directories.
os.makedirs('results', exist_ok=True)


def _batch_loss(data):
    """Run one (image, caption, target_word) batch through the model.

    Returns a tuple ``(loss, batch_size)`` where ``loss`` is the mean
    cross-entropy over the batch and ``batch_size`` is ``image.size(0)``.
    """
    image, caption, target_word = data
    out = model(image.to(device), caption.long().to(device))
    # CrossEntropyLoss needs int64 class indices; np.array of Python ints
    # may be int32 on some platforms, so cast explicitly.
    target = torch.from_numpy(np.array(target_word)).long().to(device)
    return loss_fn(out, target), image.size(0)


print("Number of epochs ", args.num_epochs)
for epoch in range(args.num_epochs):
    # NOTE(review): assumes data_generator yields a finite number of
    # batches per call; an infinite (Keras-style) generator would never
    # let these loops end — confirm.
    generator = data_generator(train_descriptions, train_features, tokenizer, max_length, vocab_size)
    test_generator = data_generator(test_descriptions, test_features, tokenizer, max_length, vocab_size)
    tr_loss, test_loss = 0.0, 0.0
    training_examples, test_examples = 0, 0

    model.train()
    for batch, data in enumerate(generator):
        loss, batch_size = _batch_loss(data)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size so that
        # tr_loss / training_examples is a true per-example average.
        tr_loss += loss.item() * batch_size
        training_examples += batch_size
        if (batch + 1) % 200 == 0:
            print("Epoch: {}, Batch: {}, loss: {}, avg loss: {}".format(epoch+1, batch+1, loss.item(), tr_loss/(training_examples)))
        if (batch + 1) % 400 == 0:
            # Save CPU copies of the weights; the live model stays on the
            # training device and in train mode.
            ckpt_model_path = os.path.join('results', 'ckpt_epoch_{}_batch_{}.pth'.format(epoch+1, batch+1))
            torch.save({k: v.cpu() for k, v in model.state_dict().items()}, ckpt_model_path)

    model.eval()
    # Evaluation needs no gradients; no_grad avoids building autograd graphs.
    with torch.no_grad():
        for data in test_generator:
            loss, batch_size = _batch_loss(data)
            test_loss += loss.item() * batch_size
            test_examples += batch_size

    # max(..., 1) guards against a generator that yielded no batches.
    print("Epoch {}, Training loss: {}, Test loss: {}".format(
        epoch + 1,
        tr_loss / max(training_examples, 1),
        test_loss / max(test_examples, 1),
    ))
print("Training Complete")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment