@joelthe1
Created October 28, 2019 18:36
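A script to compute p@1, p@3, p@5, and p@10 scores for ranked embedding matches from BioBERT. A query counts as a hit at cutoff k if at least one gold-reference item appears among its top k results; each score is the fraction of queries with a hit.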
import pickle

# Load the gold references and the ranked results; close the files afterwards.
with open('/path/to/biobert_results/gold-reference.pkl', 'rb') as f:
    ref = pickle.load(f)
with open('/path/to/biobert_results/results.pkl', 'rb') as f:
    res = pickle.load(f)

cutoffs = (1, 3, 5, 10)
scores = []
for res_key in res:
    # Skip the 'text' entry, which is not scored.
    if res_key == 'text':
        continue
    # Keep only the first element of each ranked result entry.
    cleaned_res = [x[0] for x in res[res_key]]
    gold = set(ref[res_key])
    # For each cutoff k, record 1 if any gold item is in the top k, else 0.
    scores.append([1 if gold.intersection(cleaned_res[:k]) else 0 for k in cutoffs])

# Average the per-query hits at each cutoff.
for i, k in enumerate(cutoffs):
    print('p@%d =' % k, sum(x[i] for x in scores) / len(scores))
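
The scoring logic above can be tried without the pickle files. The following is a minimal sketch, not part of the original gist: the helper names `hit_at_k` and `mean_hits` and the toy data are hypothetical, and it assumes each result entry is a tuple whose first element is the matched item, as the `x[0]` indexing in the script suggests.

```python
def hit_at_k(gold, ranked, k):
    """Return 1 if any gold item appears in the top-k ranked items, else 0."""
    return 1 if set(gold).intersection(ranked[:k]) else 0


def mean_hits(ref, res, cutoffs=(1, 3, 5, 10)):
    """Average hit_at_k over every key in res except 'text'."""
    keys = [key for key in res if key != 'text']
    if not keys:
        # No queries to score; avoid dividing by zero.
        return {k: 0.0 for k in cutoffs}
    totals = {k: 0 for k in cutoffs}
    for key in keys:
        # Assumed layout: each entry is (item, score); keep the item only.
        ranked = [entry[0] for entry in res[key]]
        for k in cutoffs:
            totals[k] += hit_at_k(ref[key], ranked, k)
    return {k: totals[k] / len(keys) for k in cutoffs}


# Hypothetical toy data standing in for the two pickle files.
toy_ref = {'q1': ['a'], 'q2': ['x', 'y']}
toy_res = {
    'text': 'ignored',
    'q1': [('b', 0.9), ('a', 0.8), ('c', 0.7)],
    'q2': [('z', 0.9), ('w', 0.8), ('v', 0.7), ('y', 0.6)],
}

toy_scores = mean_hits(toy_ref, toy_res)
print(toy_scores)  # {1: 0.0, 3: 0.5, 5: 1.0, 10: 1.0}
```

Here `q1` first hits at rank 2 and `q2` at rank 4, so neither counts at k=1, only `q1` counts at k=3, and both count from k=5 onward.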