Skip to content

Instantly share code, notes, and snippets.

@Chiang97912
Last active August 1, 2022 07:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Chiang97912/5ebca5f3fa58eff4a096119dd356e032 to your computer and use it in GitHub Desktop.
Save Chiang97912/5ebca5f3fa58eff4a096119dd356e032 to your computer and use it in GitHub Desktop.
# -*- coding:utf-8 -*-
def MRR(ranked_list, ground_truth):
    """Mean Reciprocal Rank.

    For each user ``i``, find the first 1-based position ``rank`` in
    ``ranked_list[i]`` whose item appears in ``ground_truth[i]`` and score
    ``1 / rank``. Users with no relevant item in their list score 0. The
    scores are summed and divided by ``len(ground_truth)``.

    Args:
        ranked_list: per-user ranked recommendation lists.
        ground_truth: per-user collections of relevant (hashable) items,
            aligned by index with ``ranked_list``.

    Returns:
        The MRR as a float; 0.0 when ``ground_truth`` is empty.
    """
    if not ground_truth:
        return 0.0
    rr = 0.0
    for i, ranked in enumerate(ranked_list):
        relevant = set(ground_truth[i])  # O(1) membership instead of a list scan
        for rank, item in enumerate(ranked, start=1):
            if item in relevant:
                # 1.0 forces true division, so ranks > 1 are not truncated to 0
                # when this module is run under Python 2.
                rr += 1.0 / rank
                break  # only the first relevant item counts
    return rr / len(ground_truth)
def HitRatio(ranked_list, ground_truth):
    """Hit ratio: fraction of users with at least one relevant item recommended.

    Args:
        ranked_list: per-user recommendation lists.
        ground_truth: per-user collections of relevant (hashable) items,
            aligned by index with ``ranked_list``.

    Returns:
        Number of users whose list shares at least one item with their
        ground truth, divided by ``len(ground_truth)``; 0.0 when
        ``ground_truth`` is empty.
    """
    if not ground_truth:
        return 0.0
    hits = 0
    for i, ranked in enumerate(ranked_list):
        # A "hit" only needs one common item; isdisjoint stops at the first match.
        if not set(ranked).isdisjoint(ground_truth[i]):
            hits += 1
    return float(hits) / len(ground_truth)
def AP(ranked_list, ground_truth):
    """Average Precision (AP) of a single ranked list.

    Sums precision@k over every 1-based position ``k`` that holds a relevant
    item not already counted, then divides by the number of distinct relevant
    items.

    Args:
        ranked_list: one user's ranked recommendations.
        ground_truth: that user's relevant (hashable) items.

    Returns:
        AP as a float in [0, 1]; 0.0 when no relevant item is recommended
        (including when ``ground_truth`` is empty).
    """
    relevant = set(ground_truth)
    counted = set()
    hits = 0
    sum_precs = 0.0
    for k, item in enumerate(ranked_list, start=1):
        # Count each relevant item once: a duplicated hit in ranked_list would
        # otherwise add extra precision terms and push AP above 1.
        if item in relevant and item not in counted:
            counted.add(item)
            hits += 1
            sum_precs += float(hits) / k
    if hits == 0:
        return 0.0
    # Normalize by distinct relevant items so duplicate ground-truth entries
    # do not deflate the score.
    return sum_precs / len(relevant)
def MAP(ranked_list, ground_truth):
    """Mean Average Precision (MAP).

    Computes ``AP`` for each user's ranked list against that user's ground
    truth and divides the total by ``len(ground_truth)``.

    Args:
        ranked_list: per-user ranked recommendation lists.
        ground_truth: per-user relevant items, aligned by index with
            ``ranked_list``.

    Returns:
        The MAP as a float; 0.0 when ``ground_truth`` is empty.
    """
    if not ground_truth:
        return 0.0
    total_ap = 0.0
    for i, ranked in enumerate(ranked_list):
        total_ap += AP(ranked, ground_truth[i])
    # Average once, after all users are accumulated.
    return total_ap / len(ground_truth)
def Precision(ranked_list, ground_truth):
    """Micro-averaged precision over all users.

    Counts, across all users, the distinct recommended items that are also in
    the user's ground truth, and divides by the total number of distinct
    recommended items.

    Args:
        ranked_list: per-user recommendation lists.
        ground_truth: per-user relevant (hashable) items, aligned by index
            with ``ranked_list``.

    Returns:
        Precision as a float; 0.0 when nothing is recommended at all.
    """
    n_union = 0
    recommend_sum = 0
    for i, ranked in enumerate(ranked_list):
        recom_set = set(ranked)  # duplicates in a list are counted once
        n_union += len(recom_set.intersection(ground_truth[i]))
        recommend_sum += len(recom_set)
    if recommend_sum == 0:
        # No recommendations: avoid 0/0, matching AP's 0 for the degenerate case.
        return 0.0
    return float(n_union) / recommend_sum
def Recall(ranked_list, ground_truth):
    """Micro-averaged recall over all users.

    Counts, across all users, the distinct relevant items that were
    recommended, and divides by the total number of distinct relevant items.

    Args:
        ranked_list: per-user recommendation lists.
        ground_truth: per-user relevant (hashable) items, aligned by index
            with ``ranked_list``.

    Returns:
        Recall as a float; 0.0 when there are no relevant items at all.
    """
    n_union = 0
    user_sum = 0
    for i, ranked in enumerate(ranked_list):
        truth_set = set(ground_truth[i])
        n_union += len(set(ranked) & truth_set)
        user_sum += len(truth_set)
    if user_sum == 0:
        # No ground truth: avoid 0/0, matching AP's 0 for the degenerate case.
        return 0.0
    return float(n_union) / user_sum
def Precision_V1(recommends, tests):
    """Micro-averaged precision for dict-keyed data.

    Args:
        recommends: mapping of user id -> recommended items.
        tests: mapping of user id -> relevant (hashable) items. Every key of
            ``recommends`` must be present here (a missing one raises
            KeyError).

    Returns:
        Distinct recommended items that are relevant, divided by total
        distinct recommended items; 0.0 when nothing is recommended.
    """
    n_union = 0
    recommend_sum = 0
    for user_id, items in recommends.items():
        recommend_set = set(items)
        n_union += len(recommend_set.intersection(tests[user_id]))
        recommend_sum += len(recommend_set)
    if recommend_sum == 0:
        return 0.0  # avoid 0/0 on empty input
    return float(n_union) / recommend_sum
def Recall_V1(recommends, tests):
    """Micro-averaged recall for dict-keyed data.

    Only users that appear in ``recommends`` contribute to either the
    numerator or the denominator.

    Args:
        recommends: mapping of user id -> recommended items.
        tests: mapping of user id -> relevant (hashable) items. Every key of
            ``recommends`` must be present here (a missing one raises
            KeyError).

    Returns:
        Distinct relevant items that were recommended, divided by total
        distinct relevant items; 0.0 when those users have no relevant items.
    """
    n_union = 0
    user_sum = 0
    for user_id, items in recommends.items():
        test_set = set(tests[user_id])
        n_union += len(set(items) & test_set)
        user_sum += len(test_set)
    if user_sum == 0:
        return 0.0  # avoid 0/0 on empty input
    return float(n_union) / user_sum
if __name__ == '__main__':
    # Recommendation list for each of five users.
    R = [[3, 10, 15, 12, 17], [20, 15, 18, 14, 30], [2, 5, 7, 8, 15], [56, 14, 25, 12, 19], [21, 24, 36, 54, 45]]
    # Items each user actually visited (the ground truth).
    T = [[12], [3], [5], [14], [20]]
    # Alternative ground truth with several relevant items per user:
    # T = [[12, 3, 17, 15], [3], [5, 15, 8], [14], [20, 24]]
    for metric in (MRR, HitRatio, MAP, Precision, Recall):
        print(metric(R, T))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment