Last active
December 4, 2018 18:39
-
-
Save bnsh/45d7ddcfc331ac0992267b7eb053a2d4 to your computer and use it in GitHub Desktop.
This program contains two things: (1) a precision-at-k metric that _will_ output 1.0 if all the ranks are indeed perfect, regardless of the number of items that the user had rated, and (2) a pseudo "classifier" that always outputs exactly the right predictions, effectively by cheating. (Just to show what the best one could do with ranking, and how it might c…
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python3 | |
"""This is the MovieLens example from lightfm""" | |
import numpy as np | |
from scipy.sparse import csc_matrix | |
from lightfm.datasets import fetch_movielens | |
from lightfm import LightFM | |
from lightfm.evaluation import precision_at_k | |
#pylint: disable=too-many-arguments | |
def strange_precision_at_k(model, test_interactions, train_interactions=None,
                           k=10, user_features=None, item_features=None,
                           preserve_rows=False, num_threads=1, check_intersections=True):
    """Precision@k normalised by ``min(k, number of test positives)`` per user.

    Dividing the hit count by ``k`` caps the score below 1.0 for any user with
    fewer than ``k`` test interactions.  Here the divisor is the user's count
    of non-zero test interactions clipped to ``[1, k]``, so a perfect ranking
    scores exactly 1.0 however few items the user has.

    Args:
        model: object exposing ``predict_rank(test_interactions, **kwargs)``
            that returns a sparse matrix of ranks; a stored rank below ``k``
            counts as a hit.
        test_interactions: sparse matrix of interactions to evaluate; one row
            per user.
        train_interactions: forwarded unchanged to ``model.predict_rank``.
        k: cut-off rank.
        user_features: forwarded unchanged to ``model.predict_rank``.
        item_features: forwarded unchanged to ``model.predict_rank``.
        preserve_rows: when false, rows without any stored test interaction
            are dropped from the result.
        num_threads: forwarded unchanged to ``model.predict_rank``.
        check_intersections: forwarded unchanged to ``model.predict_rank``.

    Returns:
        1-D float array with one precision value per (kept) user row.
    """
    ranks = model.predict_rank(
        test_interactions,
        train_interactions=train_interactions,
        user_features=user_features,
        item_features=item_features,
        num_threads=num_threads,
        check_intersections=check_intersections,
    )
    # Overwrite each stored rank with a hit indicator (1 when rank < k).
    ranks.data = np.less(ranks.data, k, out=ranks.data)

    # np.float64 rather than the alias np.float, which NumPy >= 1.24 no longer has.
    positives = (test_interactions != 0).astype(np.float64).sum(axis=1)
    denominator = np.clip(np.ravel(np.asarray(positives)), 1, k)

    # ravel (not squeeze) keeps the result 1-D even for a single-user matrix,
    # so the boolean row filter below always has something to index.
    hits = np.ravel(np.asarray(ranks.sum(axis=1)))
    precision = hits / denominator

    if not preserve_rows:
        precision = precision[test_interactions.getnnz(axis=1) > 0]
    return precision
#pylint: enable=too-many-arguments | |
class Cheater(object):
    """Stand-in "model" that ranks every test interaction perfectly.

    It exposes the ``fit`` / ``fit_partial`` / ``predict_rank`` methods the
    evaluation code calls, but learns nothing: ``predict_rank`` reads the test
    interactions themselves and hands each user's items the ranks 0, 1, 2, ...
    This gives the upper bound any ranking metric can reach on that data.
    """

    def __init__(self, loss):
        # Stored only so the constructor can be called like LightFM(loss=...);
        # nothing in this class reads it.
        self._loss = loss

    def fit(self, interactions, **dummy_kwargs):
        """Accept and ignore training data; return ``self``."""
        return self

    def fit_partial(self, interactions=None, **dummy_kwargs):
        """Accept and ignore training data; return ``self``."""
        return self

    #pylint: disable=no-self-use
    def predict_rank(self, test_interactions, **dummy_kwargs):
        """Return a sparse rank matrix in which every test item is a top hit.

        Args:
            test_interactions: sparse matrix (any scipy format offering
                ``tocoo()``) whose stored entries are the items to rank.
            **dummy_kwargs: accepted and ignored.

        Returns:
            ``csc_matrix`` with the same shape as ``test_interactions``.
            Within each row the stored entries, taken in ascending column
            order, receive the ranks 0.0, 1.0, 2.0, ...
        """
        coo = test_interactions.tocoo()
        # Plain tuple ordering sorts by row, then column; Python ints avoid
        # any fixed-width overflow a combined numeric sort key could hit.
        pairs = sorted(zip(coo.row.tolist(), coo.col.tolist()))

        data = []
        rows = []
        cols = []
        lastrow = None
        rank = 0.0
        for row, col in pairs:
            if row != lastrow:
                # First entry of a new user: restart the ranking at 0.
                rank = 0.0
                lastrow = row
            data.append(rank)
            rows.append(row)
            cols.append(col)
            rank += 1.0

        # An explicit shape keeps the result row-aligned with test_interactions
        # even when trailing users/items have no entries, and makes an empty
        # input produce an empty matrix instead of an error.
        return csc_matrix(
            (
                np.asarray(data, dtype=np.float64),
                (np.asarray(rows, dtype=np.int64), np.asarray(cols, dtype=np.int64)),
            ),
            shape=test_interactions.shape,
        )
#pylint: enable=no-self-use | |
def comparison(label, data, model):
    """Fit ``model`` on the training split and print both precision flavours.

    ``model`` is fitted on ``data["train"]`` (30 epochs, 2 threads).  For the
    train split and then the test split, the mean of ``precision_at_k`` and
    the mean of ``strange_precision_at_k`` (both at ``k=5``) are printed,
    each line prefixed with ``label``.
    """
    model.fit(data["train"], epochs=30, num_threads=2)

    for split_title, split_key in (("Train", "train"), ("Test", "test")):
        interactions = data[split_key]

        standard = precision_at_k(model, interactions, k=5).mean()
        print("%s Standard %s precision: %.2f" % (label, split_title, standard))

        strange = strange_precision_at_k(model, interactions, k=5).mean()
        # The trailing newline leaves a blank line between the two splits.
        print("%s Strange %s precision: %.2f\n" % (label, split_title, strange))
def main():
    """Run the precision comparison for LightFM and for the Cheater baseline.

    MovieLens is fetched with ``min_rating=5.0`` and the same data is passed
    to ``comparison`` once per model, each built with ``loss="warp"``.
    """
    data = fetch_movielens(min_rating=5.0)
    # Each model is constructed right before its own comparison run.
    for label, model_factory in (("LightFM", LightFM), ("Cheater", Cheater)):
        comparison(label, data, model_factory(loss="warp"))


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment