Skip to content

Instantly share code, notes, and snippets.

Last active August 4, 2019 16:56
Show Gist options
  • Save samarth-agrawal-86/4aea7efda2e400741da53fa22300ad1f to your computer and use it in GitHub Desktop.
Save samarth-agrawal-86/4aea7efda2e400741da53fa22300ad1f to your computer and use it in GitHub Desktop.
# Train `net` with binary cross-entropy and periodically report validation loss.
#
# NOTE(review): `net`, `lr`, `batch_size`, `train_loader` and `valid_loader`
# are not defined in this snippet; they are assumed to be created earlier in
# the file/notebook — confirm before running.

# Loss and optimization functions.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

# Training params.
epochs = 4  # 3-4 is approx where I noticed the validation loss stop decreasing
counter = 0  # number of training batches processed so far, across epochs
print_every = 100  # run validation and print stats every this many batches
clip = 5  # gradient clipping (max gradient norm)

# Move model to GPU, if available; otherwise fall back to the CPU so the
# script also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)

net.train()
# Train for some number of epochs.
for e in range(epochs):
    # Initialize hidden state at the start of every epoch.
    # NOTE(review): the hidden state is sized for `batch_size`, so every batch
    # is assumed to contain exactly `batch_size` samples (e.g. the loaders use
    # drop_last=True) — confirm.
    h = net.init_hidden(batch_size)

    # Batch loop.
    for inputs, labels in train_loader:
        counter += 1

        # `.long()` keeps the tensor on its current device; the previous
        # `.type(torch.LongTensor)` silently moved it back to the CPU.
        inputs = inputs.long().to(device)
        labels = labels.to(device)

        # Detach the hidden state from the graph of the previous batch,
        # otherwise we'd backprop through the entire training history.
        h = tuple(each.detach().to(device) for each in h)

        # Zero accumulated gradients.
        optimizer.zero_grad()

        # Get the output from the model.
        output, h = net(inputs, h)

        # Calculate the loss and perform backprop.
        loss = criterion(output.squeeze(), labels.float())
        loss.backward()
        # `clip_grad_norm_` helps prevent the exploding gradient problem in
        # RNNs / LSTMs; it must run after backward() and before step().
        nn.utils.clip_grad_norm_(net.parameters(), clip)
        optimizer.step()

        # Loss stats.
        if counter % print_every == 0:
            # Get validation loss.
            val_h = net.init_hidden(batch_size)
            val_losses = []

            # Evaluation mode, and no autograd bookkeeping during validation.
            net.eval()
            with torch.no_grad():
                for val_inputs, val_labels in valid_loader:
                    val_h = tuple(each.detach().to(device) for each in val_h)

                    val_inputs = val_inputs.long().to(device)
                    val_labels = val_labels.to(device)

                    val_output, val_h = net(val_inputs, val_h)
                    val_loss = criterion(val_output.squeeze(), val_labels.float())
                    val_losses.append(val_loss.item())

            # Back to training mode for the next batches.
            net.train()

            print("Epoch: {}/{}...".format(e+1, epochs),
                  "Step: {}...".format(counter),
                  "Loss: {:.6f}...".format(loss.item()),
                  "Val Loss: {:.6f}".format(np.mean(val_losses)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment