Skip to content

Instantly share code, notes, and snippets.

@ita9naiwa
Last active June 4, 2019 09:20
Show Gist options
  • Save ita9naiwa/c4b9adcd3707d98ead72f81bcad0a3cd to your computer and use it in GitHub Desktop.
Save ita9naiwa/c4b9adcd3707d98ead72f81bcad0a3cd to your computer and use it in GitHub Desktop.
logistic matrix factorization
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import pandas as pd \n",
"import os \n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"r = pd.read_csv(\"movielens/u.data\", sep='\\t')\n",
"r.columns =['user', 'item', 'rating','ts']"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user</th>\n",
" <th>item</th>\n",
" <th>rating</th>\n",
" <th>ts</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>186</td>\n",
" <td>302</td>\n",
" <td>3</td>\n",
" <td>891717742</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>22</td>\n",
" <td>377</td>\n",
" <td>1</td>\n",
" <td>878887116</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>244</td>\n",
" <td>51</td>\n",
" <td>2</td>\n",
" <td>880606923</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>166</td>\n",
" <td>346</td>\n",
" <td>1</td>\n",
" <td>886397596</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>298</td>\n",
" <td>474</td>\n",
" <td>4</td>\n",
" <td>884182806</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" user item rating ts\n",
"0 186 302 3 891717742\n",
"1 22 377 1 878887116\n",
"2 244 51 2 880606923\n",
"3 166 346 1 886397596\n",
"4 298 474 4 884182806"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"r.head()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"uid_to_idx = {y:x for x, y in enumerate(r.user.unique().tolist())}\n",
"iid_to_idx = {y:x for x, y in enumerate(r.item.unique().tolist())}\n",
"idx_to_uid = {x:y for x, y in enumerate(r.user.unique().tolist())}\n",
"idx_to_iid = {x:y for x, y in enumerate(r.item.unique().tolist())}"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"r['uid'] = r.user.map(lambda x: uid_to_idx[x])\n",
"r['iid'] = r.item.map(lambda x: iid_to_idx[x])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(943, 1682)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"U = len(uid_to_idx)\n",
"M = len(iid_to_idx)\n",
"(U, M)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"u_items = dict(r.groupby('uid').iid.apply(np.array))\n",
"i_users = dict(r.groupby('iid').uid.apply(np.array))"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"def calc(u, v, ub=None, vb=None):\n",
" ret = np.dot(u,v.T)\n",
" if ub is not None and vb is not None:\n",
" b = ub + vb\n",
" ret += b\n",
" return np.exp(ret)\n",
" \n",
"def update(u, u_vec, i_vec, pos_views, U, M, user=True, lr=0.003, lamb=0.01):\n",
" uv = u_vec[u, :-2]\n",
" if user:\n",
" ub = uv[-1]\n",
" else:\n",
" ub = uv[-2]\n",
" \n",
" try:\n",
" p_v = pos_views[u]\n",
" #p_scores = 1 + np.log(1 + pos_views[u][:, 1])\n",
" p_scores = np.ones_like(p_v)\n",
" except:\n",
" p_v = []\n",
" p_scores = []\n",
" \n",
" n_v = np.random.choice(M, np.min([M // 2, len(p_v) * 10]), replace=False)\n",
" p_v = np.hstack([p_v, n_v])\n",
" p_scores = np.hstack([p_scores, np.zeros_like(n_v)])\n",
" i_v = i_vec[p_v, :-2]\n",
" \n",
" if user:\n",
" ib = i_v[:, -2]\n",
" else:\n",
" ib = i_v[:, -1] \n",
" exp = calc(uv, i_v, ub, ib) \n",
" B = (1 + p_scores) * exp\n",
" C = np.divide(B, 1 + exp)\n",
"\n",
" A = np.dot(p_scores, i_v)\n",
" D = np.dot(C, i_v)\n",
" du = (A - D ) - lamb * uv\n",
" dv = (np.sum(p_scores) - np.sum(C)) - lamb * ub\n",
" ret = np.ones_like(u_vec[u])\n",
" ret[:-2] = (uv + lr * du) \n",
" if user:\n",
" ret[-1] = (ub + lr * dv)\n",
" else:\n",
" ret[-2] = (ub + lr * dv)\n",
" return ret"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"us = np.random.normal(0, 0.01, size=(U, 20))\n",
"vs = np.random.normal(0, 0.01, size=(M, 20))\n",
"us[:, -2] = 1.0\n",
"vs[:, -1] = 1.0"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"18.45311222259154 6.606965804260298\n",
"0.0007069635913750441\n",
"9.76411791693294 3.774571727624158\n",
"0.0007069635913750441\n",
"6.892110674011095 2.8077871545472375\n",
"0.00035348179568752205\n",
"5.5052403815692665 2.28960475840313\n",
"0.0010604453870625664\n",
"4.79944692348174 1.9639428104523216\n",
"0.0007069635913750441\n",
"4.59048677812209 1.7937426291608674\n",
"0.0010604453870625664\n",
"4.6062514549870155 1.7530676488786767\n",
"0.5330505478967833\n",
"4.495584754615364 1.699913457535704\n",
"0.5991516436903499\n",
"4.568251496663015 1.498026963787225\n",
"0.3092965712265818\n",
"4.626760772762859 1.321821346978116\n",
"0.416401555319901\n",
"4.683963123532821 1.2238422815609613\n",
"0.5722870272180982\n",
"4.749517847836844 1.1723391925712852\n",
"0.5638034641215977\n",
"4.804234120991001 1.1423121225481827\n",
"0.558501237186285\n",
"4.8564117416545685 1.1208772777918459\n",
"0.5652173913043478\n",
"4.90401400731277 1.102984114240213\n",
"0.5652173913043478\n",
"4.962556749806057 1.0918275514967524\n",
"0.5641569459172853\n",
"5.018016815109204 1.0809830455094898\n",
"0.5539059738423471\n",
"5.078056516083909 1.072948230232887\n",
"0.5662778366914104\n",
"5.1344621115701745 1.065829642540267\n",
"0.5680452456698479\n",
"5.188347239423277 1.059868401759393\n",
"0.5698126546482856\n",
"5.244511003839926 1.053877044812167\n",
"0.5662778366914104\n",
"5.322076176656392 1.0497995231928663\n",
"0.5659243548957228\n",
"5.382694238855233 1.0455908343524027\n",
"0.5694591728525981\n",
"5.438910969798368 1.0419320423753113\n",
"0.5563803464121597\n",
"5.511768074618116 1.0399761063760051\n",
"0.5652173913043478\n",
"5.567225293720422 1.0370519100806548\n",
"0.5676917638741604\n",
"5.627514696554446 1.0339099272087968\n",
"0.5627430187345351\n",
"5.687694915310855 1.0325661132701782\n",
"0.5652173913043478\n",
"5.753047687460456 1.0302156441159953\n",
"0.5645104277129727\n",
"5.80741499007184 1.0284133269487519\n",
"0.5652173913043478\n",
"5.866707220754546 1.0262342582214674\n",
"0.5630965005302226\n",
"5.915964358758941 1.0254664716738577\n",
"0.5645104277129727\n",
"5.973055210029575 1.0242211982094316\n",
"0.5553199010250972\n",
"6.027916067315448 1.0219485114146472\n",
"0.5652173913043478\n",
"6.074863071455239 1.0208621962433566\n",
"0.5581477553905974\n",
"6.130656324766482 1.020135384814659\n",
"0.5634499823259101\n",
"6.185412187877836 1.019224569793171\n",
"0.5659243548957228\n",
"6.229137304952745 1.0188626421962705\n",
"0.5705196182396607\n",
"6.28340675799909 1.0184204816128182\n",
"0.5659243548957228\n",
"6.326114415423463 1.0174666603073765\n",
"0.546482856132909\n",
"6.3740138178674695 1.0171226860789873\n",
"0.5655708731000353\n",
"6.42458640585839 1.016829027481995\n",
"0.5528455284552846\n",
"6.467422640364974 1.0165559684002063\n",
"0.5652173913043478\n",
"6.52060838230271 1.0160755116423004\n",
"0.5556733828207847\n",
"6.563432141907908 1.0164130119427142\n",
"0.5648639095086603\n",
"6.60696314536215 1.0161590850338782\n",
"0.5574407917992223\n",
"6.657217106499905 1.0157642592435365\n",
"0.5648639095086603\n",
"6.70232948902955 1.0160164989135207\n",
"0.5574407917992225\n",
"6.740703431844464 1.0162361366950343\n",
"0.55920820077766\n",
"6.784919119669208 1.0162324610383728\n",
"0.5655708731000353\n",
"6.8277866256369215 1.016018558890994\n",
"0.5648639095086603\n",
"6.874355150160551 1.0163753279847116\n",
"0.5592082007776599\n",
"6.912767653332144 1.0170418286165273\n",
"0.5683987274655355\n",
"6.950781013970135 1.0181478721106176\n",
"0.5609756097560976\n",
"6.994799065653214 1.0185419075189077\n",
"0.5669848002827855\n",
"7.036546737220167 1.0193205485674008\n",
"0.5733474726051608\n",
"7.08650726449552 1.0200182896334\n",
"0.574054436196536\n",
"7.125491524508782 1.020800321006764\n",
"0.5956168257334749\n",
"7.165805918193476 1.021925544538118\n",
"0.5765288087663485\n",
"7.199342372138001 1.0230095166623712\n",
"0.5804171085189113\n",
"7.238828977932087 1.0249046304836948\n",
"0.5683987274655355\n",
"7.274748374796994 1.026875165523167\n",
"0.5973842347119123\n",
"7.306434753308298 1.0283718220391616\n",
"0.5871332626369743\n",
"7.344057225651415 1.030389686549755\n",
"0.5995051254860374\n",
"7.379214637596634 1.0327720049965796\n",
"0.6051608342170378\n",
"7.416560604789014 1.0350162628461552\n",
"0.6058677978084128\n",
"7.4519428799100265 1.0372572388898782\n",
"0.5860728172499116\n",
"7.490489270819507 1.0401367817685745\n",
"0.6246023329798516\n",
"7.5178022807536315 1.0432004485580937\n",
"0.634499823259102\n",
"7.558610620054508 1.0468211281910573\n",
"0.630258041710852\n",
"7.594056322331437 1.0503815651097952\n",
"0.6327324142806645\n",
"7.619939210804441 1.055033580123976\n",
"0.6560622127960409\n",
"7.660712793173402 1.0592961611168363\n",
"0.6341463414634146\n",
"7.694495099130147 1.0638908056581033\n",
"0.6557087310003534\n",
"7.720294625758447 1.0688068752790627\n",
"0.6334393778720396\n",
"7.754548172169112 1.073848192522783\n",
"0.6546482856132908\n",
"7.7835841037189235 1.0798652830889857\n",
"0.6472251679038529\n",
"7.813332999251804 1.0860221793068552\n",
"0.6585365853658537\n",
"7.841287614198128 1.09276365671608\n",
"0.6578296217744786\n",
"7.874772479927948 1.0997133324111505\n",
"0.6553552492046659\n",
"7.905298688711339 1.1074603986757043\n",
"0.664192294096854\n",
"7.92881123515752 1.1151154692303549\n",
"0.6670201484623541\n",
"7.956179577823236 1.1239593583109877\n",
"0.6751502297631673\n",
"7.9873016833837065 1.1311207930154308\n",
"0.6808059384941676\n",
"8.013985948469601 1.1390309756417045\n",
"0.6815129020855426\n",
"8.04583041278601 1.1470198652242665\n",
"0.6864616472251679\n",
"8.077031280288058 1.1561225883381838\n",
"0.7002474372569811\n",
"8.097823430400517 1.1633767036715974\n",
"0.6896429833863555\n",
"8.121751038254072 1.1714501937120494\n",
"0.6949452103216683\n",
"8.14295631988417 1.1802750706877245\n",
"0.6931778013432308\n",
"8.173377256448603 1.1882865250134902\n",
"0.6988335100742312\n",
"8.193755548585838 1.1982506733343627\n",
"0.6988335100742312\n",
"8.220839142648206 1.2051574876049813\n",
"0.7030752916224814\n",
"8.239019539627257 1.2125233398118578\n",
"0.6998939554612937\n",
"8.260645020711133 1.2207418254981088\n",
"0.7034287734181689\n",
"8.28629250808858 1.2246431539911202\n",
"0.6988335100742312\n",
"8.305564237841512 1.2321320291739561\n",
"0.7006009190526687\n",
"8.329160362956163 1.2388141847849186\n",
"0.7030752916224814\n",
"8.351474420302138 1.247967791993751\n",
"0.7037822552138563\n",
"8.37142298532775 1.2532609164105555\n",
"0.6995404736656062\n"
]
}
],
"source": [
"for _ in range(100):\n",
" for u in range(U):\n",
" us[u] = update(u, us, vs, u_items, U, M, user=True, lr=0.01 / np.sqrt(1+_), lamb=1.0)\n",
" for m in range(M):\n",
" vs[m] = update(m, vs, us, i_users, M, U, user=False, lr=0.01 / np.sqrt(1+_), lamb=1.0)\n",
" print(np.dot(us[0], us[0]), np.dot(vs[0], vs[0]))\n",
" x = np.dot(us, vs.T).argsort()[:, ::-1][:, :3]\n",
" ret = []\n",
" for u, seen_items in u_items.items():\n",
" q = set(x[u]).intersection(seen_items.tolist())\n",
" ret.append(len(q) / 3)\n",
" print(np.mean(ret))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment