Skip to content

Instantly share code, notes, and snippets.

@rwalk
Last active June 24, 2020 03:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rwalk/384b9cce2e83c1502f607b2187176789 to your computer and use it in GitHub Desktop.
Save rwalk/384b9cce2e83c1502f607b2187176789 to your computer and use it in GitHub Desktop.
Co-occurrence-based recommendation
import argparse
import json
import sys
import warnings
import numpy as np
from scipy.sparse import load_npz, coo_matrix
class CooccurenceRecommender:
    """Co-occurrence based recommendation engine.

    A query selects a set of items by substring match on their metadata;
    every item is then scored with ``U^T (U q)``, where ``q`` is the
    indicator vector of the selected items.
    """

    # Upper bound on how many matched items are used to build one query.
    _MAX_MATCHES = 25

    def __init__(self, U, items):
        """
        U: sparse matrix, n users by m items (the original gist expects CSR)
        items: list(dict) of item metadata where each dict contains at least
            the keys title, index
        """
        self._U = U
        self._items = items

    @staticmethod
    def _field_matches(key, field_value, needle):
        """Return True if the lower-cased ``needle`` occurs in ``field_value``.

        ``field_value`` may be a string or a list of strings. A missing or
        empty field never matches.

        Raises:
            ValueError: if the field holds any other (truthy) type.
        """
        if not field_value:
            return False
        if isinstance(field_value, str):
            return needle in field_value.lower()
        if isinstance(field_value, list) and all(
            isinstance(entry, str) for entry in field_value
        ):
            return any(needle in entry.lower() for entry in field_value)
        raise ValueError(f"Field {key} is not queryable!")

    def _item_index_lookup(self, **kwargs):
        """Return the ``index`` of every item matching all query conditions.

        Each keyword is a metadata field name and its value a case-insensitive
        substring that the field must contain. An empty query matches
        nothing. At most ``_MAX_MATCHES`` indexes are returned; a warning is
        issued when further matches are dropped.

        Raises:
            ValueError: if a query value is not a string, or a queried field
                is neither a string nor a list of strings.
        """
        # Lower-case the query once instead of once per item.
        needles = {}
        for key, value in kwargs.items():
            if not isinstance(value, str):
                raise ValueError(f"Query value for field {key} must be a string!")
            needles[key] = value.lower()

        # all() over zero conditions is True, so guard the empty query here.
        if not needles:
            return []

        indexes = []
        for item in self._items:
            matched = all(
                self._field_matches(key, item.get(key), needle)
                for key, needle in needles.items()
            )
            if not matched:
                continue
            if len(indexes) >= self._MAX_MATCHES:
                warnings.warn(
                    f"More than {self._MAX_MATCHES} items matched this query. "
                    f"Only taking first {self._MAX_MATCHES}."
                )
                return indexes
            print(f"Matched item: {item}")
            indexes.append(item["index"])
        return indexes

    def _build_query_vector(self, indexes):
        """Build the sparse (m x 1) indicator vector for the given item indexes.

        ``None`` entries are skipped. Index 0 is a valid item and is kept.
        """
        rows = np.array(
            [idx for idx in indexes if idx is not None], dtype=np.intp
        )
        cols = np.zeros(len(rows), dtype=np.intp)
        ones = np.ones(len(rows), dtype=np.float64)
        n_items = self._U.shape[-1]
        return coo_matrix(
            (ones, (rows, cols)), shape=(n_items, 1), dtype=np.float64
        ).tocsr()

    def _score(self, q, number):
        """Score all items against query vector ``q`` and return the top ``number``.

        Returns a list of ``{"item": <metadata dict>, "score": <float>}``
        sorted by descending score. Items are looked up by position in the
        ``items`` list given to the constructor.
        """
        y = self._U.transpose().dot(self._U.dot(q))
        # COO exposes explicit row ids regardless of whether the product came
        # back as CSR or CSC; for an (m x 1) CSR result ``indices`` would hold
        # column ids (all zero) rather than item ids.
        y = y.tocoo()
        recs = [
            {"item": self._items[int(i)], "score": float(score)}
            for i, score in zip(y.row, y.data)
        ]
        recs.sort(key=lambda rec: rec["score"], reverse=True)
        return recs[:number]

    def recommend(self, number=10, **kwargs):
        """Return up to ``number`` recommendations for the items matching ``kwargs``.

        ``kwargs`` are field/substring conditions as accepted by
        ``_item_index_lookup``.
        """
        indexes = self._item_index_lookup(**kwargs)
        q = self._build_query_vector(indexes)
        return self._score(q, number)
if __name__ == "__main__":
    # Interactive driver: load the user-item matrix and item metadata, then
    # answer JSON queries from stdin until interrupted (Ctrl-C) or EOF.
    parser = argparse.ArgumentParser("Cooccurence Recommender")
    parser.add_argument("matrix_file", help="Sparse user item matrix in npz format")
    parser.add_argument("items_file", help="File of JSON array where each item contains at least the keys title, index")
    args = parser.parse_args()

    U = load_npz(args.matrix_file)
    with open(args.items_file, encoding="utf-8") as f:
        items = json.load(f)
    recommender = CooccurenceRecommender(U, items)

    try:
        while True:
            q_string = input("Enter a query as JSON (type 'example' for help):\n").strip()
            if not q_string:
                continue
            if q_string.lower() == "example":
                print("Example: {\"authors\": \"Paul Bowles\", \"title\": \"Sky\"}")
                continue

            try:
                query = json.loads(q_string)
            except json.decoder.JSONDecodeError:
                print("Query is not valid JSON!")
                continue

            # Valid JSON is not necessarily a usable query: it must be an
            # object mapping field names to the substrings to search for.
            if not isinstance(query, dict) or not all(
                isinstance(value, str) for value in query.values()
            ):
                print("Query must be a JSON object with string values!")
                continue

            try:
                hits = recommender.recommend(number=5, **query)
            except (TypeError, ValueError) as err:
                # e.g. a field that is not queryable, or a key that collides
                # with a recommend() parameter; keep the prompt alive.
                print(f"Query failed: {err}")
                continue

            for hit in hits:
                print(json.dumps(hit, indent=2))
    except (KeyboardInterrupt, EOFError):
        # Ctrl-C or end of input: exit quietly instead of dumping a traceback.
        sys.exit(0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment