Created
October 6, 2019 22:36
-
-
Save mooreniemi/01158d16d0a6bd64cdc94404e933395a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from heapq import nsmallest, nlargest, heappush, heapreplace | |
from random import randrange, seed | |
from copy import deepcopy | |
# Posting lists (term -> doc ids in ascending order), taken directly from
# https://trevorcohn.github.io/comp90042/slides/WSTA_L3_IR.pdf 26+
p_lists = {
    'the': [2, 3, 7, 8, 9, 10, 11, 12, 13, 17, 18, 19],
    'quick': [5, 6, 11, 14, 18],
    'brown': [2, 4, 5, 15, 42, 84, 96],
    'fox': [5, 7, 8, 13],
}

# Every distinct doc id that appears in at least one posting list.
all_doc_ids = set().union(*p_lists.values())

# setting k to this effectively turns wand off
total_doc_count = len(all_doc_ids)

# Per-term upper bound on the score a single term can contribute.
# Really this should be generated from the same mechanism as the scores;
# if it were properly dependent, we could switch weights to 1 and see full
# or ... (the original note is cut off here)
# http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.365.2939&rep=rep1&type=pdf
max_c = {'the': 0.9, 'quick': 1.9, 'brown': 2.3, 'fox': 7.1}

# Fixed seed so the randomly generated scores are reproducible; the random
# scores take the place of a real similarity function.
seed(42)
def rand_score(t, d):
    """Return a pseudo-random integer score for term ``t`` in document ``d``.

    Stands in for a real similarity function: a value is drawn with
    ``randrange(1, max_c[t] * 10)``, scaled by 0.1 and rounded to the
    nearest integer. The draw depends only on the global ``random`` state;
    ``d`` is used for the log line only.

    NOTE(review): because of the rounding, the result can exceed
    ``max_c[t]`` (e.g. for 'quick' a draw of 18 gives round(1.8) == 2,
    which is above the bound of 1.9), so ``max_c`` is not a strict upper
    bound on these scores.

    Args:
        t: term; must be a key of ``max_c``.
        d: document id.

    Returns:
        The rounded integer score.
    """
    # randrange() requires integer bounds. max_c[t] * 10 is a float, which
    # raises TypeError on Python 3.12+ (and was deprecated before that), so
    # convert explicitly. round() guards against float representation error.
    upper = int(round(max_c[t] * 10))
    score = round(randrange(1, upper) * 0.1)
    print('"%s" score in doc_id %i: %i' % (t, d, score))
    return score
# Pre-compute one score per (term, doc_id) posting. The term keys are derived
# from p_lists instead of being listed by hand, so the two structures cannot
# drift apart. rand_score() is called in the same order as before: terms in
# p_lists order, doc ids in posting-list order.
pre_scored = {
    t: {d: rand_score(t, d) for d in ds}
    for t, ds in p_lists.items()
}

# for double checking
# scored_docs = {'the': {2: 0, 3: 0, 7: 0, 8: 0, 9: 0, 10: 0, 11: 0, 12: 0, 13: 1, 17: 0, 18: 0, 19: 0}, 'quick': {5: 1, 6: 1, 11: 2, 14: 0, 18: 2}, 'brown': {2: 1, 4: 2, 5: 2, 15: 1, 42: 1, 84: 2, 96: 2}, 'fox': {5: 4, 7: 0, 8: 2, 13: 6}}
# (written without parentheses: `assert (cond, msg)` asserts a non-empty
# tuple and therefore always passes)
# assert pre_scored == scored_docs, 'consistent scores'
def fixed_score(t, doc_id):
    """Return the pre-computed score of term ``t`` in document ``doc_id``.

    Reading scores from the ``pre_scored`` structure makes inspection
    easier, so this is the scoring function used by the retrieval loop.
    Each lookup is logged. A missing term or doc id raises ``KeyError``.
    """
    per_doc = pre_scored[t]
    value = per_doc[doc_id]
    log_line = 'fixed_score for "%s" and doc_id %i is %f' % (t, doc_id, value)
    print(log_line)
    return value
def wand(q, k):
    """Return the best-scoring document for query ``q`` with WAND-style pruning.

    On every iteration the smallest doc id at the head of any remaining
    posting list becomes the candidate. The ``max_c`` upper bounds of the
    terms whose postings contain the candidate are summed; if pruning is
    enabled and that sum cannot beat the current threshold, the candidate is
    skipped without being scored. Otherwise it is scored with
    ``fixed_score`` and offered to a min-heap holding at most ``k`` results.

    The threshold is the smallest score in the heap, but only once the heap
    actually holds ``k`` entries; until then every candidate is admitted.

    Args:
        q: whitespace-separated query terms. Terms without an entry in
            ``p_lists`` are ignored.
        k: number of results kept in the heap. When ``k`` is not smaller
            than ``total_doc_count`` pruning is disabled and every candidate
            is scored.

    Returns:
        A list holding the largest ``[score, doc_id]`` entry of the heap
        (among equal scores in the heap, the higher doc id), or ``[]`` if no
        document was stored.

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        # Previously this surfaced as an IndexError from heapreplace().
        raise ValueError('k must be at least 1, got %r' % (k,))

    enabled = k < total_doc_count  # makes log messages easier etc
    times_scored, times_skipped = 0, 0  # for comparison
    scores = []  # min-heap of [score, doc_id]; scores[0] is the lowest entry

    # Only the query's own terms take part. The posting lists are copied
    # because they are consumed (mutated) while the query is evaluated.
    to_be_scored = {
        term: list(p_lists[term]) for term in q.split() if term in p_lists
    }

    while True:
        assert len(scores) <= k, 'we never store more scores than k'

        # if we've run out of postings for a term, remove it from to_be_scored
        to_be_scored = {t: ps for t, ps in to_be_scored.items() if ps}
        if not to_be_scored:
            print('ran out of postings')
            break
        print('still scoring: %r' % to_be_scored)

        # Order the terms by the doc id at the head of their posting list;
        # the first term then points at the next candidate document.
        to_be_scored = dict(sorted(to_be_scored.items(),
                                   key=lambda kv: kv[1][0]))
        current_doc_id = next(iter(to_be_scored.values()))[0]
        active_terms = [t for t, ps in to_be_scored.items()
                        if current_doc_id in ps]
        max_score = sum(max_c[t] for t in active_terms)
        print('max_score of %f for %r' % (max_score, active_terms))

        # The heap minimum is only a valid pruning threshold once k results
        # have been collected; before that nothing may be rejected.
        if len(scores) >= k:
            current_lowest = scores[0][0]
        else:
            current_lowest = float('-inf')

        should_score = max_score > current_lowest if enabled else True
        if should_score:
            if enabled:
                message = 'scoring because max score: %f > current_lowest: %f'
            else:
                message = 'scoring because max score: %f ignored, current_lowest: %f'
            print(message % (max_score, current_lowest))
            times_scored += 1
            score = sum(fixed_score(t, current_doc_id) for t in active_terms)
            print('summed score was %f, lowest score was %f' % (score, current_lowest))
            if score > current_lowest:
                print('found a new score: %f' % score)
                if len(scores) < k:
                    heappush(scores, [score, current_doc_id])
                else:
                    heapreplace(scores, [score, current_doc_id])
        else:
            print('WAND skipped because max_score: %f <= current_lowest: %f' % (max_score, current_lowest))
            times_skipped += 1

        # scored or skipped, this document must not be processed again
        for t in active_terms:
            to_be_scored[t].remove(current_doc_id)

    print('WAND enabled? %r' % enabled)
    print('times scored: %i, times skipped: %i' % (times_scored, times_skipped))
    print('final scores: %r' % scores)
    # below will return highest score and highest document id, which is fine for this example
    top_score = nlargest(1, scores)
    return top_score
# Evaluate the same query twice: once with k equal to the number of distinct
# documents (which effectively turns wand off) and once with k=2, then print
# both results side by side for comparison.
q = 'the quick brown fox'
no_wand = wand(q, k=total_doc_count)
with_wand = wand(q, k=2)

print('\n' + ('#' * 32) + ' RESULTS')
print('with_wand %r == no_wand %r' % (with_wand, no_wand))

# To verify the maximum number of iterations we should have seen:
# print('total_doc_count: %i (compare to times_scored)' % total_doc_count)
# To verify the scores:
# print('pre_scored docs: %r' % pre_scored)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Here's the output: