Search relevance evaluation metrics
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
##
## Python implementations of the search relevance evaluation metrics described at
## https://opensourceconnections.com/blog/2020/02/28/choosing-your-search-relevance-metric/
##
##
def precision(docs):
    """Return the mean of the judgments in *docs*.

    For 0/1 judgments this is the fraction of results that are relevant.
    An empty (falsy) *docs* yields 0 instead of dividing by zero.
    """
    if not docs:
        return 0
    return sum(docs) / len(docs)
def avg_precision(docs):
    """Return the average of precision@k over every rank k that holds a hit.

    A result counts as a hit when its judgment equals 1. For each hit the
    precision of the list truncated at that rank is recorded, and the mean of
    those values is returned. Returns 0 when there are no hits (including
    empty input).

    The prefix sum is carried along in a single pass, so the work is linear
    in the number of results and *docs* may be any iterable, not only a
    sliceable sequence.
    """
    running_sum = 0
    precisions_at_hits = []
    for rank, doc in enumerate(docs, start=1):
        # Sum of all judgments seen so far; divided by rank this is the
        # precision of the first `rank` results.
        running_sum += doc
        if doc == 1:
            precisions_at_hits.append(running_sum / rank)
    if not precisions_at_hits:
        return 0
    return sum(precisions_at_hits) / len(precisions_at_hits)
def cumulative_gain(docs):
    """Return the sum of the judgments in *docs*, ignoring their positions."""
    total_gain = sum(docs)
    return total_gain
def discounted_cumulative_gain(docs):
    """Return the sum of each judgment divided by a log2 position discount.

    The result at zero-based index i is divided by log2(i + 2), so the first
    result is undiscounted (log2(2) == 1) and later results count for less.
    Empty input yields 0.
    """
    from math import log2

    # Counting from 2 gives the log argument directly: index 0 -> 2, 1 -> 3, ...
    return sum(
        gain / log2(discount_arg)
        for discount_arg, gain in enumerate(docs, start=2)
    )
def alternative_discounted_cumulative_gain(docs):
    """Return DCG computed with an exponential gain of 2**judgment - 1.

    The result at zero-based index i contributes (2**d - 1) / log2(i + 2),
    which rewards high judgments more strongly than the linear-gain variant.
    Empty input yields 0.
    """
    from math import log2

    total = 0
    for index, grade in enumerate(docs):
        exponential_gain = 2 ** grade - 1
        total += exponential_gain / log2(index + 2)
    return total
def normalized_discounted_cumulative_gain(docs, top_k=5):
    """Return the DCG of the first *top_k* results divided by the ideal DCG.

    The ideal DCG is the DCG of the *top_k* highest judgments found anywhere
    in *docs*, taken in descending order. Returns 0 when the ideal DCG is 0
    (for example, empty input or all-zero judgments).

    Args:
        docs: Sliceable sequence of numeric judgments in ranked order.
        top_k: Cutoff rank. Defaults to 5, the value that was previously
            hard-coded. Because it is used as a slice bound, passing None
            evaluates the whole list.
    """
    actual = discounted_cumulative_gain(docs[:top_k])
    # Sort the full list before truncating so a high judgment ranked below
    # the cutoff still raises the ideal score.
    best_possible = discounted_cumulative_gain(sorted(docs, reverse=True)[:top_k])
    return actual / best_possible if best_possible else 0
if __name__ == '__main__':
    # Binary judgments: relevant results ranked first vs. ranked last.
    docs1 = (1, 1, 1, 0, 0)
    docs2 = (0, 0, 1, 1, 1)

    print("precision")
    print(precision(docs1), precision(docs2), '\n')

    print("avg_precision")
    print(avg_precision(docs1), avg_precision(docs2), '\n')

    # example used in https://trec.nist.gov/pubs/trec15/appendices/CE.MEASURES06.pdf
    docs3 = (1, 1, 0, 1, 0, 0, 1)
    print("avg_precision")
    print(avg_precision(docs3), '\n')

    # Graded judgments in descending order, and the same list reversed.
    docsG1 = (4, 3, 2, 1, 0)
    docsG2 = tuple(reversed(docsG1))
    for label, metric in (
        ("cumulative gain", cumulative_gain),
        ("discounted_cumulative_gain", discounted_cumulative_gain),
        ("alternative_discounted_cumulative_gain", alternative_discounted_cumulative_gain),
    ):
        print(label)
        print(metric(docsG1), metric(docsG2), '\n')

    # Two queries compared side by side.
    query1 = (4, 4, 3, 3, 3)
    query2 = (2, 1, 1, 1, 0)
    for label, metric in (
        ("alternative_discounted_cumulative_gain", alternative_discounted_cumulative_gain),
        ("normalized_discounted_cumulative_gain", normalized_discounted_cumulative_gain),
    ):
        print(label)
        print(metric(query1), metric(query2), '\n')

    # Single-query nDCG examples; the last one has more results than the cutoff.
    query3 = (3, 2, 1, 4, 0)
    query4 = (0, 1, 2, 3, 4)
    docs_all = (4, 3, 2, 1, 1, 0, 3, 4, 0, 0)
    for judged in (query3, query4, docs_all):
        print("normalized_discounted_cumulative_gain")
        print(normalized_discounted_cumulative_gain(judged), '\n')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment