@megha444
Last active October 10, 2020 07:04
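import numpy as np
import torch
from sklearn.metrics import classification_report

# model_def, device, testseq, testmask and test_y are assumed to be defined earlier in the training script

# Load the weights of the best model saved during training
pathw = 'saved_weights.pt'
model_def.load_state_dict(torch.load(pathw, map_location=device))
model_def.eval()  # switch off dropout before evaluating

# Get predictions for the test data (no gradients are needed)
with torch.no_grad():
    pred = model_def(testseq.to(device), testmask.to(device))
    pred = pred.cpu().numpy()

# Check model performance: take the highest-scoring class for each example
pred = np.argmax(pred, axis=1)
print(classification_report(test_y, pred))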
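Pushing the whole test set through the model in a single forward pass can run out of GPU memory when the test set is large. Below is a minimal sketch of a batched alternative that also shows a `state_dict` save/load round trip with `map_location`. It assumes, like `model_def`, that the model takes `(seq, mask)` and returns one row of class scores per example (here log-probabilities). `TinyClassifier`, `predict_in_batches`, the toy data and the batch size of 32 are hypothetical stand-ins, not part of the original code.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for model_def: (seq, mask) -> log-probabilities."""

    def __init__(self, vocab_size=100, dim=8, num_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, seq, mask):
        emb = self.embed(seq) * mask.unsqueeze(-1)  # zero out padded positions
        # Mean over the non-padded tokens of each sequence
        pooled = emb.sum(1) / mask.sum(1, keepdim=True).clamp(min=1)
        return torch.log_softmax(self.fc(self.dropout(pooled)), dim=1)


def predict_in_batches(model, seq, mask, device, batch_size=32):
    """Run inference batch by batch and return all class scores as a NumPy array."""
    model.eval()  # disable dropout so predictions are deterministic
    loader = DataLoader(TensorDataset(seq, mask), batch_size=batch_size)
    outputs = []
    with torch.no_grad():  # no autograd graph is needed for inference
        for batch_seq, batch_mask in loader:
            out = model(batch_seq.to(device), batch_mask.to(device))
            outputs.append(out.cpu())  # collect results on the CPU
    return torch.cat(outputs).numpy()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)

model = TinyClassifier().to(device)
seq = torch.randint(0, 100, (70, 12))  # 70 toy "test" sequences of 12 token ids
mask = torch.ones_like(seq)             # no padding in this toy data

# Save and reload a state_dict; map_location remaps the saved tensors onto the
# current device (e.g. GPU-saved weights onto a CPU-only machine)
path = os.path.join(tempfile.mkdtemp(), "saved_weights.pt")
torch.save(model.state_dict(), path)
restored = TinyClassifier().to(device)
restored.load_state_dict(torch.load(path, map_location=device))

scores = predict_in_batches(model, seq, mask, device)
restored_scores = predict_in_batches(restored, seq, mask, device)
print(scores.shape)  # (70, 2)
```

Because both models are in eval mode and share the same weights, the reloaded model produces the same scores as the original; batching only changes peak memory use, not the result.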
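`np.argmax(pred, axis=1)` turns each row of class scores into the index of the highest-scoring class, and `classification_report` then summarizes per-class precision, recall and F1. A minimal NumPy sketch of what those per-class numbers mean follows; `per_class_report` and the toy arrays are hypothetical, and the sketch reproduces neither the report's text layout nor its averaged rows.

```python
import numpy as np


def per_class_report(y_true, scores):
    """Return {class: (precision, recall, f1, support)} from raw class scores."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores)
    y_pred = np.argmax(scores, axis=1)  # highest-scoring class per row
    report = {}
    for c in range(scores.shape[1]):
        tp = np.sum((y_pred == c) & (y_true == c))  # predicted c, really c
        fp = np.sum((y_pred == c) & (y_true != c))  # predicted c, really other
        fn = np.sum((y_pred != c) & (y_true == c))  # missed a true c
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        support = int(np.sum(y_true == c))  # number of true examples of class c
        report[c] = (float(precision), float(recall), float(f1), support)
    return report


scores = np.array([[2.0, 0.1], [0.3, 1.5], [1.2, 0.4], [0.2, 0.9], [0.8, 0.6]])
y_true = [0, 1, 1, 1, 0]
print(per_class_report(y_true, scores))
```

For these toy inputs the argmax predictions are `[0, 1, 0, 1, 0]`, so class 0 gets precision 2/3 and recall 1, class 1 gets precision 1 and recall 2/3, and both have F1 0.8.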