Gist by @dmesquita, last active October 25, 2017 12:26
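The training loop calls a `get_batch` helper that this gist does not define. The sketch below is a minimal stand-in. It assumes `newsgroups_train` behaves like the object returned by scikit-learn's `fetch_20newsgroups`, with `.data` as a list of texts and `.target` as integer labels. It also assumes articles are encoded as bag-of-words count vectors over a word-to-index `vocab`. The names `make_get_batch`, `text_to_bow` and `vocab` are hypothetical and are not the author's code. The sketch also shows a side effect of computing `total_batch` with floor division: the last `len(data) % batch_size` articles are skipped in every epoch.

```python
# Hypothetical stand-in for get_batch (not the author's implementation).
# Assumes dataset.data is a list of strings and dataset.target holds int labels.
from types import SimpleNamespace


def text_to_bow(text, vocab):
    """Return a bag-of-words count vector of length len(vocab)."""
    vec = [0.0] * len(vocab)
    for word in text.lower().split():
        if word in vocab:
            vec[vocab[word]] += 1.0
    return vec


def make_get_batch(vocab):
    """Build a get_batch(dataset, i, batch_size) function bound to `vocab`."""

    def get_batch(dataset, i, batch_size):
        # Slice the i-th batch (0-indexed) of texts and labels.
        start = i * batch_size
        end = start + batch_size
        features = [text_to_bow(t, vocab) for t in dataset.data[start:end]]
        labels = list(dataset.target[start:end])
        return features, labels

    return get_batch


# Toy stand-in for newsgroups_train: five short "articles" with integer labels.
toy = SimpleNamespace(
    data=["good game", "bad game", "good good", "bad", "game"],
    target=[0, 1, 0, 1, 0],
)
vocab = {"good": 0, "bad": 1, "game": 2}
get_batch = make_get_batch(vocab)

batch_size = 2
total_batch = len(toy.data) // batch_size  # 2: the fifth article is never used
x, y = get_batch(toy, 1, batch_size)
print(total_batch, x, y)  # 2 [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]] [0, 1]
```

The closure keeps the three-argument call signature used in the loop, `get_batch(newsgroups_train, i, batch_size)`, without relying on a global vocabulary.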
```python
import torch

# Train the Model
# Assumes net, criterion, optimizer, get_batch, newsgroups_train,
# num_epochs and batch_size are defined earlier in the script.
for epoch in range(num_epochs):
    total_batch = len(newsgroups_train.data) // batch_size
    for i in range(total_batch):
        batch_x, batch_y = get_batch(newsgroups_train, i, batch_size)
        articles = torch.tensor(batch_x, dtype=torch.float32)
        # Float targets as in the original; if criterion is nn.CrossEntropyLoss
        # with class-index targets, use dtype=torch.long instead.
        labels = torch.tensor(batch_y, dtype=torch.float32)

        # Forward + Backward + Optimize
        optimizer.zero_grad()  # zero the gradient buffer
        outputs = net(articles)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f'
              % (epoch + 1, num_epochs, i + 1, total_batch, loss.item()))
```
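PyTorch adds each `backward()` call's gradients into the parameters' `.grad` fields rather than overwriting them. For that reason, the loop calls `optimizer.zero_grad()` before every batch. The framework-free sketch below mimics that behaviour on toy linear-regression data (y = 2x + 1). `Param`, `SGD` and `forward_backward` are hypothetical illustrations of the call order only. They are not PyTorch's implementation.

```python
# Framework-free sketch of the "Forward + Backward + Optimize" cycle.
# Param, SGD and forward_backward are hypothetical; they mimic the PyTorch
# behaviour that backward() *adds* into .grad instead of overwriting it.


class Param:
    def __init__(self, value):
        self.value = value
        self.grad = 0.0


class SGD:
    def __init__(self, params, lr):
        self.params = params
        self.lr = lr

    def zero_grad(self):
        # Reset accumulated gradients, like optimizer.zero_grad()
        for p in self.params:
            p.grad = 0.0

    def step(self):
        # Plain gradient-descent update, like optimizer.step() for SGD
        for p in self.params:
            p.value -= self.lr * p.grad


def forward_backward(w, b, xs, ys):
    """MSE loss of y_hat = w*x + b; adds dL/dw and dL/db into .grad."""
    n = len(xs)
    loss = 0.0
    for x, y in zip(xs, ys):
        err = w.value * x + b.value - y
        loss += err * err / n
        w.grad += 2 * err * x / n  # accumulate, as loss.backward() does
        b.grad += 2 * err / n
    return loss


xs, ys = [0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0]  # toy data on y = 2x + 1

# Without zero_grad, a second backward pass doubles the stored gradient.
w, b = Param(0.0), Param(0.0)
forward_backward(w, b, xs, ys)
first_grad = w.grad        # -17.0
forward_backward(w, b, xs, ys)
accumulated_grad = w.grad  # -34.0

# Same order as the training loop: zero_grad -> forward/backward -> step.
w, b = Param(0.0), Param(0.0)
opt = SGD([w, b], lr=0.05)
for epoch in range(500):
    opt.zero_grad()
    loss = forward_backward(w, b, xs, ys)
    opt.step()

print(round(w.value, 3), round(b.value, 3))  # 2.0 1.0
```

If the `zero_grad()` call is dropped, each step uses the sum of all past gradients instead of the current one, so the updates no longer follow plain gradient descent.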