Skip to content

Instantly share code, notes, and snippets.

@Shiina18
Created December 26, 2023 08:59
Show Gist options
  • Save Shiina18/cee062009e2a3995afa3acc79c853b0c to your computer and use it in GitHub Desktop.
Save Shiina18/cee062009e2a3995afa3acc79c853b0c to your computer and use it in GitHub Desktop.
Some simple Python utilities copied from other sources.
"""
Simple Python utilities copied from other sources.

``find_best_f1_and_threshold`` is copied from sentence-transformers,
file sentence_transformers/evaluation/BinaryClassificationEvaluator.py.
"""
def find_best_f1_and_threshold(scores, labels, high_score_more_similar: bool):
    """Find the score threshold that maximises F1 for binary labels.

    Pairs are sorted by score: descending when ``high_score_more_similar``
    is true, ascending otherwise. Each cut between two consecutive
    *distinct* scores is tried, treating the pairs before the cut as
    predicted positives. The cut with the highest F1 wins.

    Args:
        scores: Sequence of scores, one per pair.
        labels: Sequence of ground-truth labels, same length as ``scores``;
            a label equal to 1 counts as positive.
        high_score_more_similar: True if a higher score means "more likely
            positive" (e.g. a similarity), False if lower is (e.g. a distance).

    Returns:
        Tuple ``(best_f1, best_precision, best_recall, threshold)``.
        ``threshold`` is the midpoint between the two scores on either side
        of the best cut. All four stay 0 when no evaluated cut has a true
        positive (including empty input or input with no positive labels).

    Raises:
        AssertionError: if ``scores`` and ``labels`` differ in length
            (checked with ``assert``, so skipped under ``python -O``).
    """
    # Local import: nothing at module level brings numpy into scope.
    import numpy as np

    assert len(scores) == len(labels)
    scores = np.asarray(scores)
    labels = np.asarray(labels)
    # sorted() is stable, so tied scores keep their input order.
    rows = sorted(
        zip(scores, labels), key=lambda x: x[0], reverse=high_score_more_similar
    )

    best_f1 = best_precision = best_recall = 0
    threshold = 0
    nextract = 0
    ncorrect = 0
    total_num_duplicates = sum(labels)

    # NOTE: the last row is never used as a cut (there is no following score
    # to take a midpoint with), so "predict every pair positive" is not
    # evaluated. This matches the upstream implementation.
    for (score, label), (next_score, _) in zip(rows, rows[1:]):
        nextract += 1
        if label == 1:
            ncorrect += 1
        # A cut between two equal scores cannot be realised by any threshold:
        # keep accumulating counts, but only evaluate cuts between distinct
        # scores so the reported metrics are actually achievable.
        if ncorrect == 0 or score == next_score:
            continue
        precision = ncorrect / nextract
        recall = ncorrect / total_num_duplicates
        f1 = 2 * precision * recall / (precision + recall)
        if f1 > best_f1:
            best_f1 = f1
            best_precision = precision
            best_recall = recall
            threshold = (score + next_score) / 2
    return best_f1, best_precision, best_recall, threshold
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment