import torch
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix


# Evaluation Function
def evaluate(model, test_loader):
    y_pred = []
    y_true = []

    model.eval()
    with torch.no_grad():
        for (labels, title, text, titletext), _ in test_loader:
            # Cast to LongTensor and move labels and the title+text field to the target device
            labels = labels.type(torch.LongTensor)
            labels = labels.to(device)
            titletext = titletext.type(torch.LongTensor)
            titletext = titletext.to(device)

            # The model returns (loss, logits); keep only the logits
            output = model(titletext, labels)
            _, output = output

            # Predicted class is the index of the largest logit
            y_pred.extend(torch.argmax(output, 1).tolist())
            y_true.extend(labels.tolist())

    # Per-class precision, recall, and F1 (1 = FAKE, 0 = REAL)
    print('Classification Report:')
    print(classification_report(y_true, y_pred, labels=[1, 0], digits=4))

    # Plot the confusion matrix as an annotated heatmap
    cm = confusion_matrix(y_true, y_pred, labels=[1, 0])
    ax = plt.subplot()
    sns.heatmap(cm, annot=True, ax=ax, cmap='Blues', fmt="d")

    ax.set_title('Confusion Matrix')
    ax.set_xlabel('Predicted Labels')
    ax.set_ylabel('True Labels')
    ax.xaxis.set_ticklabels(['FAKE', 'REAL'])
    ax.yaxis.set_ticklabels(['FAKE', 'REAL'])


# Load the best checkpoint saved during training and evaluate it on the test set
best_model = BERT().to(device)

load_checkpoint(destination_folder + '/model.pt', best_model)
evaluate(best_model, test_iter)
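

# The snippet above assumes that `device`, `BERT`, `load_checkpoint`,
# `destination_folder`, and `test_iter` are defined earlier in the tutorial.
# Below is only a minimal sketch of what those definitions might look like;
# the actual class and helper used by the author may differ.
import torch.nn as nn
from transformers import BertForSequenceClassification

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class BERT(nn.Module):
    # Assumed thin wrapper around a pretrained BERT sequence classifier
    def __init__(self):
        super(BERT, self).__init__()
        self.encoder = BertForSequenceClassification.from_pretrained('bert-base-uncased')

    def forward(self, text, label):
        # When labels are passed, the encoder output exposes (loss, logits)
        # as its first two elements, matching the `_, output = output` unpacking above
        loss, logits = self.encoder(text, labels=label)[:2]
        return loss, logits


def load_checkpoint(load_path, model):
    # Assumed checkpoint format: a dict with the model weights under 'model_state_dict'
    state_dict = torch.load(load_path, map_location=device)
    model.load_state_dict(state_dict['model_state_dict'])
    return state_dict.get('valid_loss')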