Last active
November 7, 2017 20:14
-
-
Save JasonTam/40bd3f5144578b4fc258310187e5af67 to your computer and use it in GitHub Desktop.
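Example of how to freeze embeddings in LightFM by exploiting the accumulated gradient in Adagrad: setting a row's gradient accumulator to a very large value shrinks its per-parameter step, learning_rate / sqrt(accumulator), to (near) zero, so that row effectively stops training.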
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/jtam/tools/lightfm/lightfm/_lightfm_fast.py:9: UserWarning: LightFM was compiled without OpenMP support. Only a single thread will be used.\n", | |
" warnings.warn('LightFM was compiled without OpenMP support. '\n" | |
] | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"\n", | |
"from lightfm.datasets import fetch_movielens\n", | |
"\n", | |
"movielens = fetch_movielens()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"train = movielens['train']\n", | |
"test = movielens['test']" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Fitting models" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from lightfm import LightFM\n", | |
"from lightfm.evaluation import precision_at_k\n", | |
"from lightfm.evaluation import auc_score\n", | |
"import sys" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"n_dims = 10\n", | |
"model = LightFM(learning_rate=0.05, loss='bpr', random_state=322)\n", | |
"\n", | |
"# Factor transfer\n", | |
"model._initialize(no_components=n_dims,\n", | |
" no_item_features=train.shape[1],\n", | |
" no_user_features=train.shape[0],)\n", | |
"\n", | |
"# Set embeddings for 2 users\n", | |
"model.user_embeddings[0,:] = 0.5 # we expect this to be the same\n", | |
"model.user_embeddings[1,:] = 0.7 # we expect this to change\n", | |
"\n", | |
"# Freeze only the first one\n", | |
"model.user_embedding_gradients[0,:] = sys.maxsize\n", | |
"model.user_bias_gradients[0] = sys.maxsize" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<lightfm.lightfm.LightFM at 0x10e58ff60>" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.fit_partial(train, epochs=10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Precision: train 0.59, test 0.10.\n", | |
"AUC: train 0.89, test 0.86.\n" | |
] | |
} | |
], | |
"source": [ | |
"train_precision = precision_at_k(model, train, k=10).mean()\n", | |
"test_precision = precision_at_k(model, test, k=10).mean()\n", | |
"\n", | |
"train_auc = auc_score(model, train).mean()\n", | |
"test_auc = auc_score(model, test).mean()\n", | |
"\n", | |
"print('Precision: train %.2f, test %.2f.' % (train_precision, test_precision))\n", | |
"print('AUC: train %.2f, test %.2f.' % (train_auc, test_auc))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 0.5 , 0.5 , 0.5 , 0.5 , 0.5 ,\n", | |
" 0.5 , 0.5 , 0.5 , 0.5 , 0.5 ],\n", | |
" [ 0.09772108, -0.03140765, 1.07019258, 1.09194338, 0.03139753,\n", | |
" 0.4900538 , 1.37998486, 1.55365562, 0.36961207, 0.29496703]], dtype=float32)" | |
] | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.user_embeddings[:2,:]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ -1.78466397e-08, -1.55268562e+00], dtype=float32)" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.user_biases[:2]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# " | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python [ipykernel_py3]", | |
"language": "python", | |
"name": "Python [ipykernel_py3]" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.0" | |
}, | |
"widgets": { | |
"state": {}, | |
"version": "1.1.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment