@prateekjoshi565
Created July 28, 2020 07:53
import torch
from torch import nn

def train(net, epochs=10, batch_size=32, lr=0.001, clip=1, print_every=32):
    # optimizer
    opt = torch.optim.Adam(net.parameters(), lr=lr)

    # loss function
    criterion = nn.CrossEntropyLoss()

    # push model to GPU
    net.cuda()

    counter = 0
    net.train()

    for e in range(epochs):
        # initialize hidden state
        h = net.init_hidden(batch_size)

        # get_batches, x_int and y_int are defined elsewhere (outside this gist)
        for x, y in get_batches(x_int, y_int, batch_size):
            counter += 1

            # convert numpy arrays to PyTorch tensors
            inputs, targets = torch.from_numpy(x), torch.from_numpy(y)

            # push tensors to GPU
            inputs, targets = inputs.cuda(), targets.cuda()

            # detach hidden states so gradients don't flow back through the whole history
            h = tuple([each.data for each in h])

            # zero accumulated gradients
            net.zero_grad()

            # get the output from the model
            output, h = net(inputs, h)

            # calculate the loss
            loss = criterion(output, targets.view(-1))

            # back-propagate the error
            loss.backward()

            # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs
            nn.utils.clip_grad_norm_(net.parameters(), clip)

            # update weights
            opt.step()

            if counter % print_every == 0:
                print("Epoch: {}/{}...".format(e + 1, epochs),
                      "Step: {}...".format(counter))