Skip to content

Instantly share code, notes, and snippets.

@samarth-agrawal-86
Last active March 26, 2020 21:50
Show Gist options
  • Save samarth-agrawal-86/2c7e1c9db9f36dbf79760a2070e406c7 to your computer and use it in GitHub Desktop.
Save samarth-agrawal-86/2c7e1c9db9f36dbf79760a2070e406c7 to your computer and use it in GitHub Desktop.
Testing on the test dataset
# Get test data loss and accuracy.
#
# Runs `net` over every batch yielded by `test_loader`, collecting the
# per-batch loss in `test_losses` and the number of correct predictions in
# `num_correct`, then prints the mean batch loss and the overall accuracy
# (also stored in `test_acc`).
# Relies on names defined earlier in the script/notebook: `net`,
# `test_loader`, `criterion`, `batch_size`, `train_on_gpu`, `torch`, `np`.
test_losses = []  # per-batch loss values
num_correct = 0  # running count of correct predictions
num_evaluated = 0  # number of predictions actually scored (accuracy denominator)

# Initialise the hidden state once; it is carried from batch to batch below.
# NOTE(review): the state is sized for `batch_size`, so a smaller final batch
# (a DataLoader built without drop_last=True) would presumably not match it --
# confirm how test_loader is constructed. Carrying state across unrelated
# batches is also worth confirming as intended.
h = net.init_hidden(batch_size)

net.eval()
# Evaluation needs no gradients: no_grad avoids building an autograd graph for
# every batch and keeps the carried hidden state free of graph history.
with torch.no_grad():
    for inputs, labels in test_loader:
        # Drop any autograd history from the incoming hidden state.
        h = tuple(each.detach() for each in h)

        if train_on_gpu:
            inputs, labels = inputs.cuda(), labels.cuda()

        # Cast inputs to int64. `.long()` keeps the tensor on its current
        # device; `.type(torch.LongTensor)` names the CPU tensor type and
        # would move GPU inputs back to the CPU.
        inputs = inputs.long()
        output, h = net(inputs, h)

        # One prediction per sample. Unlike a bare `squeeze()`, `reshape(-1)`
        # keeps a 1-element batch as shape (1,), so it still matches `labels`.
        # assumes net emits one value per sample, shape (B,) or (B, 1) -- TODO confirm
        output = output.reshape(-1)
        labels = labels.float()

        test_loss = criterion(output, labels)
        test_losses.append(test_loss.item())

        # Round outputs to the nearest integer to get class predictions
        # (0 or 1, assuming outputs are probabilities in [0, 1]).
        pred = torch.round(output)
        correct_tensor = pred.eq(labels.view_as(pred))
        # Count on-device and pull back a single Python int; works for both
        # CPU and GPU tensors without a numpy round-trip.
        num_correct += int(correct_tensor.sum().item())
        num_evaluated += correct_tensor.numel()

# -- stats! -- ##
# Average of the per-batch losses (nan when the loader yielded no batches).
avg_test_loss = np.mean(test_losses) if test_losses else float("nan")
print("Test loss: {:.3f}".format(avg_test_loss))
# Accuracy over the predictions actually scored; equals
# num_correct / len(test_loader.dataset) whenever every sample is evaluated.
test_acc = num_correct / num_evaluated if num_evaluated else float("nan")
print("Test accuracy: {:.3f}".format(test_acc))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment