Skip to content

Instantly share code, notes, and snippets.

@simonw
Last active January 6, 2019 22:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save simonw/e0b9156d66b41b172a66d0cfe32d9391 to your computer and use it in GitHub Desktop.
Save simonw/e0b9156d66b41b172a66d0cfe32d9391 to your computer and use it in GitHub Desktop.
Demonstrating a bug in Peewee's bm25 function - see https://github.com/coleifer/peewee/issues/1826
import math
import struct
import sqlite3
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE VIRTUAL TABLE docs USING fts4(c0, c1);
INSERT INTO docs (c0, c1) VALUES ("this is about a dog", "more about that dog dog");
INSERT INTO docs (c0, c1) VALUES ("this is about a cat", "stuff on that cat cat");
INSERT INTO docs (c0, c1) VALUES ("something about a ferret", "yeah a ferret ferret");
INSERT INTO docs (c0, c1) VALUES ("both of them", "both dog dog and cat here");
INSERT INTO docs (c0, c1) VALUES ("not mammals", "maybe talk about fish");
""")
def _parse_match_info(buf):
bufsize = len(buf) # Length in bytes.
return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)]
def bm25(match_info, *args):
"""
Usage:
# Format string *must* be pcnalx
# Second parameter to bm25 specifies the index of the column, on
# the table being queries.
bm25(matchinfo(document_tbl, 'pcnalx'), 1) AS rank
"""
K = 1.2
B = 0.75
score = 0.0
P_O, C_O, N_O, A_O = range(4)
term_count = match_info[P_O]
col_count = match_info[C_O]
total_docs = match_info[N_O]
print("term_count={}, col_count={}, total_docs={}".format(
term_count, col_count, total_docs
))
L_O = A_O + col_count
X_O = L_O + col_count
if not args:
weights = [1] * col_count
else:
weights = [0] * col_count
for i, weight in enumerate(args):
weights[i] = args[i]
for i in range(term_count):
for j in range(col_count):
weight = weights[j]
if weight == 0:
continue
print("term (i) = {}, column (j) = {}".format(i, j))
avg_length = float(match_info[A_O + j])
doc_length = float(match_info[L_O + j])
print(" avg_length={}, doc_length={}".format(avg_length, doc_length))
if avg_length == 0:
D = 0
else:
D = 1 - B + (B * (doc_length / avg_length))
x = X_O + (3 * j * (i + 1))
term_frequency = float(match_info[x])
docs_with_term = float(match_info[x + 2])
print(" term_frequency_in_this_column={}, docs_with_term_in_this_column={}".format(
term_frequency, docs_with_term
))
idf = max(
math.log(
(total_docs - docs_with_term + 0.5) /
(docs_with_term + 0.5)),
0)
denom = term_frequency + (K * D)
if denom == 0:
rhs = 0
else:
rhs = (term_frequency * (K + 1)) / denom
score += (idf * rhs) * weight
return -score
for search in ("dog", "dog cat"):
results = conn.execute("""
select *, matchinfo(docs, 'pcnalx') from docs
where docs match ?
""", [search]).fetchall()
print('search = {}'.format(search))
print("============")
for r in results:
print(r[:2])
print(_parse_match_info(r[-1]))
print(bm25(_parse_match_info(r[-1])))
print()
search = dog
============
('this is about a dog', 'more about that dog dog')
[1, 2, 5, 4, 5, 5, 5, 1, 1, 1, 2, 4, 2]
term_count=1, col_count=2, total_docs=5
term (i) = 0, column (j) = 0
avg_length=4.0, doc_length=5.0
term_frequency_in_this_column=1.0, docs_with_term_in_this_column=1.0
term (i) = 0, column (j) = 1
avg_length=5.0, doc_length=5.0
term_frequency_in_this_column=2.0, docs_with_term_in_this_column=2.0
-1.45932851507369
('both of them', 'both dog dog and cat here')
[1, 2, 5, 4, 5, 3, 6, 0, 1, 1, 2, 4, 2]
term_count=1, col_count=2, total_docs=5
term (i) = 0, column (j) = 0
avg_length=4.0, doc_length=3.0
term_frequency_in_this_column=0.0, docs_with_term_in_this_column=1.0
term (i) = 0, column (j) = 1
avg_length=5.0, doc_length=6.0
term_frequency_in_this_column=2.0, docs_with_term_in_this_column=2.0
-0.438011195601579
search = dog cat
============
('both of them', 'both dog dog and cat here')
[2, 2, 5, 4, 5, 3, 6, 0, 1, 1, 2, 4, 2, 0, 1, 1, 1, 3, 2]
term_count=2, col_count=2, total_docs=5
term (i) = 0, column (j) = 0
avg_length=4.0, doc_length=3.0
term_frequency_in_this_column=0.0, docs_with_term_in_this_column=1.0
term (i) = 0, column (j) = 1
avg_length=5.0, doc_length=6.0
term_frequency_in_this_column=2.0, docs_with_term_in_this_column=2.0
term (i) = 1, column (j) = 0
avg_length=4.0, doc_length=3.0
term_frequency_in_this_column=0.0, docs_with_term_in_this_column=1.0
term (i) = 1, column (j) = 1
avg_length=5.0, doc_length=6.0
term_frequency_in_this_column=0.0, docs_with_term_in_this_column=1.0
-0.438011195601579
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment