Last active
August 2, 2020 04:55
-
-
Save uni-3/559aeb67e86fb022b8480263b7608b1c to your computer and use it in GitHub Desktop.
matrix factorization with tensorflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.6" | |
}, | |
"colab": { | |
"name": "mf_keras.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/uni-3/559aeb67e86fb022b8480263b7608b1c/mf_keras.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "5wLON6eufKPP", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import datetime\n", | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"import sklearn.model_selection\n", | |
"import tensorflow as tf\n", | |
"import matplotlib.pyplot as plt\n", | |
"import os\n", | |
"import sys" | |
], | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "G3HSls6VfKPu", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"#### setup dataset" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "v79IwzJXfKPw", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 526 | |
}, | |
"outputId": "ec80c737-9fc5-4a12-a199-2b6a0d31c663" | |
}, | |
"source": [ | |
"# download dataset\n", | |
"!mkdir -p data\n", | |
"# -f: fail on HTTP errors instead of saving an error page as the zip; -L: follow redirects\n", | |
"!curl -fL 'https://files.grouplens.org/datasets/movielens/ml-100k.zip' -o ./data/ml-100k.zip\n", | |
"!unzip -o ./data/ml-100k.zip -d data" | |
], | |
"execution_count": 18, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
" % Total % Received % Xferd Average Speed Time Time Time Current\n", | |
" Dload Upload Total Spent Left Speed\n", | |
"\r 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\r100 4808k 100 4808k 0 0 12.0M 0 --:--:-- --:--:-- --:--:-- 12.0M\n", | |
"Archive: ./data/ml-100k.zip\n", | |
" inflating: data/ml-100k/allbut.pl \n", | |
" inflating: data/ml-100k/mku.sh \n", | |
" inflating: data/ml-100k/README \n", | |
" inflating: data/ml-100k/u.data \n", | |
" inflating: data/ml-100k/u.genre \n", | |
" inflating: data/ml-100k/u.info \n", | |
" inflating: data/ml-100k/u.item \n", | |
" inflating: data/ml-100k/u.occupation \n", | |
" inflating: data/ml-100k/u.user \n", | |
" inflating: data/ml-100k/u1.base \n", | |
" inflating: data/ml-100k/u1.test \n", | |
" inflating: data/ml-100k/u2.base \n", | |
" inflating: data/ml-100k/u2.test \n", | |
" inflating: data/ml-100k/u3.base \n", | |
" inflating: data/ml-100k/u3.test \n", | |
" inflating: data/ml-100k/u4.base \n", | |
" inflating: data/ml-100k/u4.test \n", | |
" inflating: data/ml-100k/u5.base \n", | |
" inflating: data/ml-100k/u5.test \n", | |
" inflating: data/ml-100k/ua.base \n", | |
" inflating: data/ml-100k/ua.test \n", | |
" inflating: data/ml-100k/ub.base \n", | |
" inflating: data/ml-100k/ub.test \n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "9R6dj_ctfP71", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 36 | |
}, | |
"outputId": "d01ef12c-bb98-459c-b6f9-a1e544f723a6" | |
}, | |
"source": [ | |
"!ls data" | |
], | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"ml-100k ml-100k.zip\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dA0a3uwXfKP7", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def id_to_index(df):\n", | |
"    \"\"\"\n", | |
"    Build id -> index and index -> id lookup tables for users and items.\n", | |
"\n", | |
"    Ids are enumerated in sorted order so the assigned indices are reproducible\n", | |
"    across runs, and every inverse table is derived from its forward table so\n", | |
"    the two can never disagree.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    df: pd.DataFrame\n", | |
"        must contain the columns 'user_id' and 'item_id' (ids must be sortable)\n", | |
"\n", | |
"    Returns\n", | |
"    -------\n", | |
"    user_to_index, item_to_index, index_to_user, index_to_item: dict\n", | |
"    \"\"\"\n", | |
"    user_to_index = {uid: idx for idx, uid in enumerate(sorted(set(df['user_id'])))}\n", | |
"    item_to_index = {iid: idx for idx, iid in enumerate(sorted(set(df['item_id'])))}\n", | |
"    # invert the forward tables instead of enumerating the ids a second time\n", | |
"    index_to_user = {idx: uid for uid, idx in user_to_index.items()}\n", | |
"    index_to_item = {idx: iid for iid, idx in item_to_index.items()}\n", | |
"\n", | |
"    return user_to_index, item_to_index, \\\n", | |
"        index_to_user, index_to_item\n", | |
"\n", | |
"def load_movielens(input_file=\"./data/ml-100k/u.data\") -> (pd.DataFrame, dict, dict, dict, dict):\n", | |
"    \"\"\"\n", | |
"    Read a tab-separated MovieLens ratings file and attach index columns.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    input_file: string\n", | |
"        path of a header-less file with the columns\n", | |
"        user_id, item_id, rating, timestamp\n", | |
"\n", | |
"    Returns\n", | |
"    -------\n", | |
"    the ratings frame (with 'user_id_index' and 'item_id_index' added)\n", | |
"    followed by the four lookup tables returned by id_to_index\n", | |
"    \"\"\"\n", | |
"    column_dtypes = {\n", | |
"        'user_id': np.int32,\n", | |
"        'item_id': np.int32,\n", | |
"        'rating': np.float32,\n", | |
"        'timestamp': np.int32,\n", | |
"    }\n", | |
"    ratings = pd.read_csv(\n", | |
"        input_file,\n", | |
"        sep='\\t',\n", | |
"        header=None,\n", | |
"        names=['user_id', 'item_id', 'rating', 'timestamp'],\n", | |
"        dtype=column_dtypes,\n", | |
"    )\n", | |
"\n", | |
"    # lookups = (user_to_index, item_to_index, index_to_user, index_to_item)\n", | |
"    lookups = id_to_index(ratings)\n", | |
"\n", | |
"    # translate the raw ids into their indices\n", | |
"    ratings['user_id_index'] = ratings['user_id'].map(lookups[0])\n", | |
"    ratings['item_id_index'] = ratings['item_id'].map(lookups[1])\n", | |
"    return (ratings, *lookups)\n" | |
], | |
"execution_count": 20, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "0SFz5z2bfKQD", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def prep_data(df, col_names=['user_id', 'item_id'], target='rating', norm=True, n_nega_sample: int=None) -> (pd.DataFrame, int, int):\n", | |
"    \"\"\"\n", | |
"    Build (user, item, score) training rows, optionally with sampled negatives.\n", | |
"\n", | |
"    The interactions are pivoted into a dense user x item matrix (missing pairs\n", | |
"    filled with 0) and stacked back into rows; pairs with score > 0 are the\n", | |
"    positive samples, pairs with score == 0 are the negative candidates.\n", | |
"    The input frame is never modified.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    df: pd.DataFrame\n", | |
"        [user_id, item_id, score]\n", | |
"    col_names: list[string]\n", | |
"        name of user_id and item_id\n", | |
"    target: string\n", | |
"        name of the score column, in the input and in the result\n", | |
"    norm: bool\n", | |
"        if truthy, every observed score is replaced by 1\n", | |
"    n_nega_sample: int\n", | |
"        maximum number of negative (score 0) pairs sampled per user\n", | |
"        if None no negatives are added\n", | |
"    Returns\n", | |
"    -------\n", | |
"    df_com: pd.DataFrame\n", | |
"        positive rows followed by the sampled negative rows\n", | |
"    n_user: int\n", | |
"        number of distinct users in df\n", | |
"    n_item: int\n", | |
"        number of distinct items in df\n", | |
"    \"\"\"\n", | |
"    user_col, item_col = col_names[0], col_names[1]\n", | |
"\n", | |
"    if norm:\n", | |
"        # assign() returns a new frame, so the caller's scores stay intact\n", | |
"        df = df.assign(**{target: 1})\n", | |
"    n_user = df[user_col].nunique()\n", | |
"    n_item = df[item_col].nunique()\n", | |
"\n", | |
"    # stack() yields an unnamed series, hence the column label 0 after reset_index()\n", | |
"    df_mat = pd.pivot_table(df, index=user_col,\n", | |
"                            columns=item_col, values=target, fill_value=0\n", | |
"                            ).stack().reset_index().rename(columns={0: target})\n", | |
"\n", | |
"    # positive samples\n", | |
"    df_posi = df_mat[df_mat[target] > 0].reset_index(drop=True)\n", | |
"    df_com = df_posi\n", | |
"\n", | |
"    if n_nega_sample is not None:\n", | |
"        df_nega = df_mat[df_mat[target] == 0]\n", | |
"\n", | |
"        # Sample group by group instead of through groupby.apply: iterating the\n", | |
"        # groups always yields complete rows (user column included), whereas\n", | |
"        # apply handing the grouping column to the callback is deprecated.\n", | |
"        nega_samples = [\n", | |
"            group.sample(n=min(n_nega_sample, len(group)), random_state=14)\n", | |
"            for _, group in df_nega.groupby(user_col)\n", | |
"        ]\n", | |
"\n", | |
"        # pd.concat rejects an empty list, so only merge when something was sampled\n", | |
"        if nega_samples:\n", | |
"            df_com = pd.concat([df_posi] + nega_samples).reset_index(drop=True)\n", | |
"\n", | |
"    return df_com, n_user, n_item\n" | |
], | |
"execution_count": 21, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "EUvuur09fKQM", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Load the ratings together with the id <-> index lookup tables.\n", | |
"(df_raw, user_to_index, item_to_index,\n", | |
" index_to_user, index_to_item) = load_movielens()\n", | |
"df_raw.shape\n", | |
"\n", | |
"# Names of the index-encoded id columns. They are not passed to prep_data below,\n", | |
"# so prep_data falls back to its default 'user_id' / 'item_id' columns.\n", | |
"col_names = ['user_id_index', 'item_id_index']\n", | |
"target = 'rating'\n", | |
"\n", | |
"# Keep the raw scores (norm=None) and sample at most 5 negative pairs per user.\n", | |
"df, n_user, n_item = prep_data(df_raw, norm=None, n_nega_sample=5)" | |
], | |
"execution_count": 22, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "8GHGqf3lfKQY", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 206 | |
}, | |
"outputId": "4699eab1-19bd-4a56-d26e-e9688fcfc6ff" | |
}, | |
"source": [ | |
"df_raw.head()" | |
], | |
"execution_count": 23, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>user_id</th>\n", | |
" <th>item_id</th>\n", | |
" <th>rating</th>\n", | |
" <th>timestamp</th>\n", | |
" <th>user_id_index</th>\n", | |
" <th>item_id_index</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>196</td>\n", | |
" <td>242</td>\n", | |
" <td>3.0</td>\n", | |
" <td>881250949</td>\n", | |
" <td>195</td>\n", | |
" <td>241</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>186</td>\n", | |
" <td>302</td>\n", | |
" <td>3.0</td>\n", | |
" <td>891717742</td>\n", | |
" <td>185</td>\n", | |
" <td>301</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>22</td>\n", | |
" <td>377</td>\n", | |
" <td>1.0</td>\n", | |
" <td>878887116</td>\n", | |
" <td>21</td>\n", | |
" <td>376</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>244</td>\n", | |
" <td>51</td>\n", | |
" <td>2.0</td>\n", | |
" <td>880606923</td>\n", | |
" <td>243</td>\n", | |
" <td>50</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>166</td>\n", | |
" <td>346</td>\n", | |
" <td>1.0</td>\n", | |
" <td>886397596</td>\n", | |
" <td>165</td>\n", | |
" <td>345</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" user_id item_id rating timestamp user_id_index item_id_index\n", | |
"0 196 242 3.0 881250949 195 241\n", | |
"1 186 302 3.0 891717742 185 301\n", | |
"2 22 377 1.0 878887116 21 376\n", | |
"3 244 51 2.0 880606923 243 50\n", | |
"4 166 346 1.0 886397596 165 345" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 23 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "u-G9piITfKQh", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 206 | |
}, | |
"outputId": "e2c4c6b0-b9de-4be0-9f05-153715552961" | |
}, | |
"source": [ | |
"df.head()" | |
], | |
"execution_count": 24, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>user_id</th>\n", | |
" <th>item_id</th>\n", | |
" <th>rating</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>2</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>1</td>\n", | |
" <td>3</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>1</td>\n", | |
" <td>4</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>1</td>\n", | |
" <td>5</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" user_id item_id rating\n", | |
"0 1 1 5\n", | |
"1 1 2 3\n", | |
"2 1 3 4\n", | |
"3 1 4 3\n", | |
"4 1 5 3" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 24 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "fsINgd_RfKQo", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Hold out 5% of the rows as the test set; the fixed seed keeps the split repeatable.\n", | |
"X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(\n", | |
"    df[['user_id', 'item_id']],\n", | |
"    df[target],\n", | |
"    test_size=0.05,\n", | |
"    random_state=1,\n", | |
")" | |
], | |
"execution_count": 25, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ARWEtqnYfKQw", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 36 | |
}, | |
"outputId": "10c48f0b-c529-4f9e-8f45-87b8a3c34c84" | |
}, | |
"source": [ | |
"X_train.shape" | |
], | |
"execution_count": 26, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(99479, 2)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 26 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2le9n0mSfKQ7", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 36 | |
}, | |
"outputId": "651ec478-4b2e-431d-bca5-66023f2f69ba" | |
}, | |
"source": [ | |
"X_test.shape" | |
], | |
"execution_count": 27, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(5236, 2)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 27 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "fXJYnroCfKRG", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 36 | |
}, | |
"outputId": "77ee3de1-3f36-4f05-f30b-04e74f7a3591" | |
}, | |
"source": [ | |
"X_train['user_id'].nunique()" | |
], | |
"execution_count": 28, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"943" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 28 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "LKeZLlLVfKRN", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 36 | |
}, | |
"outputId": "7811ae25-e96e-46f3-9ad9-4a6ab01a0f3e" | |
}, | |
"source": [ | |
"X_test['user_id'].nunique()" | |
], | |
"execution_count": 29, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"877" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 29 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "G7C06ajwfKRZ", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Matrix-factorization model: one embedding vector per user and per item,\n", | |
"# scored by the dot product of the two vectors.\n", | |
"user_id_input = tf.keras.layers.Input(shape=[1], name='user')\n", | |
"item_id_input = tf.keras.layers.Input(shape=[1], name='item')\n", | |
"\n", | |
"\n", | |
"# number of latent factors per user / item\n", | |
"embedding_size = 8\n", | |
"# input_dim is n_user+1 / n_item+1 -- presumably because the raw ids fed to the\n", | |
"# model start at 1 rather than 0; TODO confirm\n", | |
"user_embedding = tf.keras.layers.Embedding(output_dim=embedding_size, input_dim=n_user+1,\n", | |
" #embeddings_initializer='glorot_uniform',\n", | |
" #embeddings_regularizer=tf.keras.regularizers.l2(l=0.01),\n", | |
" input_length=1,\n", | |
" name='user_embedding')(user_id_input)\n", | |
"item_embedding = tf.keras.layers.Embedding(output_dim=embedding_size, input_dim=n_item+1,\n", | |
" #embeddings_initializer='glorot_uniform',\n", | |
" #embeddings_regularizer=tf.keras.regularizers.l2(l=0.01),\n", | |
" input_length=1,\n", | |
" name='item_embedding')(item_id_input)\n", | |
"\n", | |
"# drop the length-1 axis so each sample is a flat vector of embedding_size values\n", | |
"user_vecs = tf.keras.layers.Reshape([embedding_size])(user_embedding)\n", | |
"item_vecs = tf.keras.layers.Reshape([embedding_size])(item_embedding)\n", | |
"\n", | |
"# Final prediction layer\n", | |
"# inner product(score)\n", | |
"# NOTE: the layer is named \"concat\" although it is a Dot layer\n", | |
"y = tf.keras.layers.Dot(1, normalize=False, name=\"concat\")([user_vecs, item_vecs])\n", | |
"#y = tf.linalg.matmul([user_vecs, item_vecs])\n", | |
"\n", | |
"model = tf.keras.Model(inputs=[user_id_input, item_id_input], outputs=y)\n", | |
"\n", | |
"# regress the score directly: mean squared error minimised with Adam\n", | |
"model.compile(loss='mse',\n", | |
" optimizer=\"adam\"\n", | |
" )" | |
], | |
"execution_count": 30, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "h6tqn0vWfKRj", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 74 | |
}, | |
"outputId": "a4a60d7b-490b-4bf3-cc86-69bec60b2995" | |
}, | |
"source": [ | |
"!pip install pydot-ng graphviz " | |
], | |
"execution_count": 31, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: pydot-ng in /usr/local/lib/python3.6/dist-packages (2.0.0)\n", | |
"Requirement already satisfied: graphviz in /usr/local/lib/python3.6/dist-packages (0.10.1)\n", | |
"Requirement already satisfied: pyparsing>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from pydot-ng) (2.4.7)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pDfJmM3_fKRr", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 369 | |
}, | |
"outputId": "537daf61-f6be-4bd1-9f81-0088e8be180e" | |
}, | |
"source": [ | |
"# show model\n", | |
"tf.keras.utils.plot_model(\n", | |
" model, #to_file='model.png',\n", | |
" show_shapes=False, show_layer_names=True,\n", | |
" rankdir='TB', expand_nested=True, dpi=96\n", | |
")\n" | |
], | |
"execution_count": 32, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<IPython.core.display.Image object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 32 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "sviUZ9_UfKRx", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 451 | |
}, | |
"outputId": "398f354f-85c0-4334-889e-da48b23265fa" | |
}, | |
"source": [ | |
"model.summary()" | |
], | |
"execution_count": 33, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Model: \"model\"\n", | |
"__________________________________________________________________________________________________\n", | |
"Layer (type) Output Shape Param # Connected to \n", | |
"==================================================================================================\n", | |
"user (InputLayer) [(None, 1)] 0 \n", | |
"__________________________________________________________________________________________________\n", | |
"item (InputLayer) [(None, 1)] 0 \n", | |
"__________________________________________________________________________________________________\n", | |
"user_embedding (Embedding) (None, 1, 8) 7552 user[0][0] \n", | |
"__________________________________________________________________________________________________\n", | |
"item_embedding (Embedding) (None, 1, 8) 13464 item[0][0] \n", | |
"__________________________________________________________________________________________________\n", | |
"reshape (Reshape) (None, 8) 0 user_embedding[0][0] \n", | |
"__________________________________________________________________________________________________\n", | |
"reshape_1 (Reshape) (None, 8) 0 item_embedding[0][0] \n", | |
"__________________________________________________________________________________________________\n", | |
"concat (Dot) (None, 1) 0 reshape[0][0] \n", | |
" reshape_1[0][0] \n", | |
"==================================================================================================\n", | |
"Total params: 21,016\n", | |
"Trainable params: 21,016\n", | |
"Non-trainable params: 0\n", | |
"__________________________________________________________________________________________________\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2kQ3kZPtfKR5", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
}, | |
"outputId": "3dce9058-1bac-4195-ee2a-0f11f9055af8" | |
}, | |
"source": [ | |
"import pickle\n", | |
"import time\n", | |
"\n", | |
"# Save the model to ./models_fm only when the validation loss improves.\n", | |
"checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n", | |
"    './models_fm',\n", | |
"    monitor='val_loss',\n", | |
"    save_best_only=True,\n", | |
"    freq='epoch',\n", | |
")\n", | |
"\n", | |
"# One TensorBoard log directory per run, named after the current timestamp.\n", | |
"log_dir = os.path.join(\n", | |
"    \"./models_fm/logs\",\n", | |
"    datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\"),\n", | |
")\n", | |
"tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n", | |
"\n", | |
"# 10% of the training rows are set aside for validation.\n", | |
"history = model.fit(\n", | |
"    x=[X_train[\"user_id\"], X_train[\"item_id\"]],\n", | |
"    y=y_train,\n", | |
"    batch_size=128,\n", | |
"    epochs=20,\n", | |
"    validation_split=0.1,\n", | |
"    callbacks=[checkpoint_callback, tensorboard_callback],\n", | |
"    shuffle=True,\n", | |
")\n", | |
"\n", | |
"# Persist the recorded losses next to the checkpoint.\n", | |
"with open(os.path.join(\"models_fm\", \"history.pkl\"), \"wb\") as f:\n", | |
"    pickle.dump(history.history, f)" | |
], | |
"execution_count": 34, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Epoch 1/20\n", | |
"699/700 [============================>.] - ETA: 0s - loss: 12.4237WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"If using Keras pass *_constraint arguments to layers.\n", | |
"INFO:tensorflow:Assets written to: ./models_fm/assets\n", | |
"700/700 [==============================] - 2s 3ms/step - loss: 12.4219 - val_loss: 9.6668\n", | |
"Epoch 2/20\n", | |
"681/700 [============================>.] - ETA: 0s - loss: 5.5264INFO:tensorflow:Assets written to: ./models_fm/assets\n", | |
"700/700 [==============================] - 2s 3ms/step - loss: 5.4549 - val_loss: 2.7801\n", | |
"Epoch 3/20\n", | |
"682/700 [============================>.] - ETA: 0s - loss: 2.0878INFO:tensorflow:Assets written to: ./models_fm/assets\n", | |
"700/700 [==============================] - 2s 3ms/step - loss: 2.0768 - val_loss: 1.6733\n", | |
"Epoch 4/20\n", | |
"697/700 [============================>.] - ETA: 0s - loss: 1.4977INFO:tensorflow:Assets written to: ./models_fm/assets\n", | |
"700/700 [==============================] - 2s 2ms/step - loss: 1.4976 - val_loss: 1.3814\n", | |
"Epoch 5/20\n", | |
"699/700 [============================>.] - ETA: 0s - loss: 1.3131INFO:tensorflow:Assets written to: ./models_fm/assets\n", | |
"700/700 [==============================] - 2s 3ms/step - loss: 1.3134 - val_loss: 1.2714\n", | |
"Epoch 6/20\n", | |
"698/700 [============================>.] - ETA: 0s - loss: 1.2366INFO:tensorflow:Assets written to: ./models_fm/assets\n", | |
"700/700 [==============================] - 2s 2ms/step - loss: 1.2371 - val_loss: 1.2249\n", | |
"Epoch 7/20\n", | |
"692/700 [============================>.] - ETA: 0s - loss: 1.2012INFO:tensorflow:Assets written to: ./models_fm/assets\n", | |
"700/700 [==============================] - 2s 2ms/step - loss: 1.2008 - val_loss: 1.2033\n", | |
"Epoch 8/20\n", | |
"675/700 [===========================>..] - ETA: 0s - loss: 1.1818INFO:tensorflow:Assets written to: ./models_fm/assets\n", | |
"700/700 [==============================] - 2s 2ms/step - loss: 1.1818 - val_loss: 1.1919\n", | |
"Epoch 9/20\n", | |
"675/700 [===========================>..] - ETA: 0s - loss: 1.1696INFO:tensorflow:Assets written to: ./models_fm/assets\n", | |
"700/700 [==============================] - 2s 3ms/step - loss: 1.1706 - val_loss: 1.1882\n", | |
"Epoch 10/20\n", | |
"678/700 [============================>.] - ETA: 0s - loss: 1.1647INFO:tensorflow:Assets written to: ./models_fm/assets\n", | |
"700/700 [==============================] - 2s 3ms/step - loss: 1.1639 - val_loss: 1.1855\n", | |
"Epoch 11/20\n", | |
"700/700 [==============================] - 1s 2ms/step - loss: 1.1593 - val_loss: 1.1859\n", | |
"Epoch 12/20\n", | |
"695/700 [============================>.] - ETA: 0s - loss: 1.1563INFO:tensorflow:Assets written to: ./models_fm/assets\n", | |
"700/700 [==============================] - 2s 3ms/step - loss: 1.1562 - val_loss: 1.1836\n", | |
"Epoch 13/20\n", | |
"693/700 [============================>.] - ETA: 0s - loss: 1.1534INFO:tensorflow:Assets written to: ./models_fm/assets\n", | |
"700/700 [==============================] - 2s 3ms/step - loss: 1.1537 - val_loss: 1.1825\n", | |
"Epoch 14/20\n", | |
"700/700 [==============================] - 1s 2ms/step - loss: 1.1516 - val_loss: 1.1844\n", | |
"Epoch 15/20\n", | |
"675/700 [===========================>..] - ETA: 0s - loss: 1.1484INFO:tensorflow:Assets written to: ./models_fm/assets\n", | |
"700/700 [==============================] - 2s 3ms/step - loss: 1.1494 - val_loss: 1.1810\n", | |
"Epoch 16/20\n", | |
"700/700 [==============================] - 1s 2ms/step - loss: 1.1475 - val_loss: 1.1822\n", | |
"Epoch 17/20\n", | |
"700/700 [==============================] - 1s 2ms/step - loss: 1.1457 - val_loss: 1.1812\n", | |
"Epoch 18/20\n", | |
"700/700 [==============================] - 1s 2ms/step - loss: 1.1438 - val_loss: 1.1815\n", | |
"Epoch 19/20\n", | |
"700/700 [==============================] - 1s 2ms/step - loss: 1.1416 - val_loss: 1.1813\n", | |
"Epoch 20/20\n", | |
"681/700 [============================>.] - ETA: 0s - loss: 1.1383INFO:tensorflow:Assets written to: ./models_fm/assets\n", | |
"700/700 [==============================] - 2s 2ms/step - loss: 1.1389 - val_loss: 1.1804\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "TKC1qsN-fKSA", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import pickle\n", | |
"\n", | |
"# Write the recorded training history to models_fm/history.pkl.\n", | |
"history_path = os.path.join(\"models_fm\", \"history.pkl\")\n", | |
"with open(history_path, \"wb\") as history_file:\n", | |
"    pickle.dump(history.history, history_file)" | |
], | |
"execution_count": 35, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "1ZsBmNlOfKSH", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 36 | |
}, | |
"outputId": "f563fad7-bb19-4554-e39e-4a4abd29236f" | |
}, | |
"source": [ | |
"# look the layer up by name; its position in model.layers is an implementation detail\n", | |
"model.get_layer('item_embedding').get_weights()[0].shape" | |
], | |
"execution_count": 36, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(1683, 8)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 36 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "DWVtifB6fKSM", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 36 | |
}, | |
"outputId": "ddadb22e-c25b-48c1-cc26-18c9fa3b381d" | |
}, | |
"source": [ | |
"# look the layer up by name; its position in model.layers is an implementation detail\n", | |
"model.get_layer('user_embedding').get_weights()[0].shape" | |
], | |
"execution_count": 37, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(944, 8)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 37 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "w2IDZ5FEfKSW", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# output embeddings\n", | |
"# Layers are fetched by name rather than by position in model.layers, so the\n", | |
"# export keeps working if layers are added to or reordered in the model.\n", | |
"pd.DataFrame(model.get_layer('user_embedding').get_weights()[0]).to_csv('./user_emb_8.csv')\n", | |
"pd.DataFrame(model.get_layer('item_embedding').get_weights()[0]).to_csv('./item_emb_8.csv')\n" | |
], | |
"execution_count": 38, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "eoL-CMGWfKSg", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 206 | |
}, | |
"outputId": "40f783ea-09c6-4289-e803-827f48b68168" | |
}, | |
"source": [ | |
"# Score of every (user, item) pair = user factors @ item factors^T.\n", | |
"# Layers are fetched by name rather than by position in model.layers.\n", | |
"# NOTE(review): row / column 0 presumably belongs to an id that never occurs in\n", | |
"# the training data (raw ids appear to start at 1) -- confirm before using it.\n", | |
"df_scores = pd.DataFrame(\n", | |
"    model.get_layer('user_embedding').get_weights()[0]\n", | |
"    @ model.get_layer('item_embedding').get_weights()[0].T\n", | |
").stack().reset_index()\n", | |
"df_scores.columns = ['user_id', 'item_id', 'rating']\n", | |
"df_scores.to_csv('./scores.csv', index=False)\n", | |
"df_scores.head()\n" | |
], | |
"execution_count": 39, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>user_id</th>\n", | |
" <th>item_id</th>\n", | |
" <th>rating</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0.001375</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>-0.060442</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>0</td>\n", | |
" <td>2</td>\n", | |
" <td>-0.048946</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>0</td>\n", | |
" <td>3</td>\n", | |
" <td>-0.048313</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>0</td>\n", | |
" <td>4</td>\n", | |
" <td>-0.050197</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" user_id item_id rating\n", | |
"0 0 0 0.001375\n", | |
"1 0 1 -0.060442\n", | |
"2 0 2 -0.048946\n", | |
"3 0 3 -0.048313\n", | |
"4 0 4 -0.050197" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 39 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "TSybDjVyfKSm", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# umap\n", | |
"import pandas as pd\n", | |
"import umap\n", | |
"import time\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"#df_user_emb = pd.read_csv('./user_emb_8.csv')\n", | |
"#df_item_emb = pd.read_csv('./item_emb_8.csv')\n", | |
"#embedding = umap.UMAP().fit_transform(df_item_emb[[\"0\",\"1\",\"2\",\"3\",\"4\",\"5\",\"6\",\"7\"]].values)\n", | |
"#pd.DataFrame(embedding).to_csv('item_emb_umap.csv')" | |
], | |
"execution_count": 40, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "fCTc45hqfKSs", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 36 | |
}, | |
"outputId": "fec48ccf-bb58-454e-80a5-cdbf9161354d" | |
}, | |
"source": [ | |
"# shape of the full user x item score matrix (layers fetched by name)\n", | |
"pd.DataFrame(\n", | |
"    model.get_layer('user_embedding').get_weights()[0]\n", | |
"    @ model.get_layer('item_embedding').get_weights()[0].T\n", | |
").shape" | |
], | |
"execution_count": 41, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(944, 1683)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 41 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Q3bJS9QofKSx", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"#### TensorBoard in notebooks: https://www.tensorflow.org/tensorboard/tensorboard_in_notebooks" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "vABcwTIifKSy", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 41, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "KXE5j0eFfKS4", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# NOTE(review): 437 is a hard-coded PID, presumably a TensorBoard process from an earlier session -- confirm before running.\n", | |
"!kill 437" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "X_vFtJ3BfKS8", | |
"colab_type": "code", | |
"colab": {}, | |
"outputId": "352afdaf-e71b-4bc5-8653-a528205e8142" | |
}, | |
"source": [ | |
"from tensorboard import notebook\n", | |
"notebook.list() # View open TensorBoard instances" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"No known TensorBoard instances running.\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "A7gnzyCCfKTB", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "shMtu5fefKTG", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"%load_ext tensorboard" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Wrn5VyNEfKTL", | |
"colab_type": "code", | |
"colab": {}, | |
"outputId": "aa89e246-ccf7-4318-d239-44f6280bb1d6" | |
}, | |
"source": [ | |
"%reload_ext tensorboard\n", | |
"%tensorboard --logdir ./models_fm/logs --host 0.0.0.0" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"\n", | |
" <iframe id=\"tensorboard-frame-8722f0df074f4d28\" width=\"100%\" height=\"800\" frameborder=\"0\">\n", | |
" </iframe>\n", | |
" <script>\n", | |
" (function() {\n", | |
" const frame = document.getElementById(\"tensorboard-frame-8722f0df074f4d28\");\n", | |
" const url = new URL(\"/\", window.location);\n", | |
" url.port = 6006;\n", | |
" frame.src = url;\n", | |
" })();\n", | |
" </script>\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "anW_lEhufKTO", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"y_pred = model.predict(x=[X_test[\"user_id\"], X_test[\"item_id\"]])\n" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "8eJWWrGSfKTU", | |
"colab_type": "code", | |
"colab": {}, | |
"outputId": "8f3ed7a7-9686-4b9b-a29f-cb8803b09f97" | |
}, | |
"source": [ | |
"y_pred.shape" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(20000, 1)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 27 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "qNN3BO0IfKTa", | |
"colab_type": "code", | |
"colab": {}, | |
"outputId": "76b03505-0c3b-4a16-95a0-369ad3257b46" | |
}, | |
"source": [ | |
"print(f'''user {X_test[\"user_id\"].values[0]}のitem {X_test[\"item_id\"].values[0]}に対する\\n\n", | |
"予測値: {y_pred[0][0]}\\n\n", | |
"真の値: {y_test.values[0]}\n", | |
"''')" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"user 265のitem 284に対する\n", | |
"\n", | |
"予測値: 3.259321689605713\n", | |
"\n", | |
"真の値: 4\n", | |
"\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "EIDZr1JtfKTf", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"from sklearn.metrics import mean_squared_error\n", | |
"# calc RMSE\n", | |
"np.sqrt(mean_squared_error(y_test, y_pred))" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "sSehlno-fKTm", | |
"colab_type": "code", | |
"colab": {}, | |
"outputId": "c0e12881-2498-41cb-e625-2881b2c3d1f3" | |
}, | |
"source": [ | |
"# 重みとオプティマイザを含む全く同じモデルを再作成\n", | |
"model_s = './models_fm'\n", | |
"model_l = tf.keras.models.load_model(model_s)\n", | |
"\n", | |
"model_l.summary()\n" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Model: \"model\"\n", | |
"__________________________________________________________________________________________________\n", | |
"Layer (type) Output Shape Param # Connected to \n", | |
"==================================================================================================\n", | |
"user (InputLayer) [(None, 1)] 0 \n", | |
"__________________________________________________________________________________________________\n", | |
"item (InputLayer) [(None, 1)] 0 \n", | |
"__________________________________________________________________________________________________\n", | |
"user_embedding (Embedding) (None, 1, 30) 2400000 user[0][0] \n", | |
"__________________________________________________________________________________________________\n", | |
"item_embedding (Embedding) (None, 1, 30) 2400000 item[0][0] \n", | |
"__________________________________________________________________________________________________\n", | |
"reshape (Reshape) (None, 30) 0 user_embedding[0][0] \n", | |
"__________________________________________________________________________________________________\n", | |
"reshape_1 (Reshape) (None, 30) 0 item_embedding[0][0] \n", | |
"__________________________________________________________________________________________________\n", | |
"dot (Dot) (None, 1) 0 reshape[0][0] \n", | |
" reshape_1[0][0] \n", | |
"==================================================================================================\n", | |
"Total params: 4,800,000\n", | |
"Trainable params: 4,800,000\n", | |
"Non-trainable params: 0\n", | |
"__________________________________________________________________________________________________\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dZ_qj4ogfKTq", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"y_pred = model_l.predict(x=[X_test[\"user_id\"], X_test[\"item_id\"]])\n" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "JwUAGj8SfKTt", | |
"colab_type": "code", | |
"colab": {}, | |
"outputId": "43a2fd06-3a40-4373-cf2c-d8d41761e86e" | |
}, | |
"source": [ | |
"from sklearn.metrics import mean_squared_error\n", | |
"# calc RMSE\n", | |
"np.sqrt(mean_squared_error(y_test, y_pred))" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0.93278754" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 28 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "fbltKPwrfKTy", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"### predict topn" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "97LJTYJ3fKTz", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"import tensorflow as tf\n", | |
"\n", | |
"import itertools\n", | |
"\n", | |
"def all(model: tf.keras.Model, users: list, items: list, n: int=None):\n", | |
" \"\"\"\n", | |
" 全組み合わせについてスコアを取得\n", | |
"\n", | |
" Parameters\n", | |
" ----------\n", | |
" model: tf.keras.Model\n", | |
" user: list\n", | |
" items: list\n", | |
" n: int, default None\n", | |
" if is not None, return topn each user\n", | |
"\n", | |
" Returns\n", | |
" -------\n", | |
" df\n", | |
"\n", | |
" \"\"\"\n", | |
" # 全組み合わせ\n", | |
" ul = np.array([u for u, _ in itertools.product(list(set(users)),\n", | |
" list(set(items)))])\n", | |
" il = np.array([i for _, i in itertools.product(list(set(users)),\n", | |
" list(set(items)))])\n", | |
" y_pred = model.predict(x=[il, ul]).flatten()\n", | |
"\n", | |
" df = pd.DataFrame(\n", | |
" {\n", | |
" \"user_id_index\": ul,\n", | |
" \"item_id_index\": il,\n", | |
" \"score\": np.array(y_pred)\n", | |
" }\n", | |
" )\n", | |
"\n", | |
" # return topn each user\n", | |
" if n is not None:\n", | |
" return df.groupby([\"user_id_index\"]).apply(lambda row: row.nlargest(n, 'score')).reset_index(drop=True)\n", | |
"\n", | |
"\n", | |
" return df\n", | |
"\n", | |
"\n", | |
"def topn(model: tf.keras.Model, user_id: int, items: list, n=10):\n", | |
" \"\"\"\n", | |
"\n", | |
" Parameters\n", | |
" ----------\n", | |
" model\n", | |
" user_id: int, str\n", | |
" user_id\n", | |
" items: list\n", | |
" item_ids\n", | |
" n: int, default 10\n", | |
" topn\n", | |
"\n", | |
" Returns\n", | |
" -------\n", | |
" df\n", | |
"\n", | |
" \"\"\"\n", | |
" y_pred = model.predict(x=[np.array(items),\n", | |
" np.array([user_id]*len(items))])\n", | |
"\n", | |
" # to rating\n", | |
" #y_rat = [int(abs(round(i[0]))) for i in y]\n", | |
" y = [i[0] for i in y_pred]\n", | |
" # index of topn\n", | |
" idx = np.argsort(y)[::-1][:n]\n", | |
"\n", | |
" df = pd.DataFrame(\n", | |
" {\n", | |
" \"user_id_index\": [user_id] * n,\n", | |
" \"item_id_index\": np.array(items)[idx],\n", | |
" \"score\": np.array(y)[idx]\n", | |
" }\n", | |
" )\n", | |
"\n", | |
" return df" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "z5p5C-j8fKT7", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# id\n", | |
"n = 5\n", | |
"pred_user = user_to_index[user_id]\n", | |
"pred_items = list(set(X_test['item_id_index']))\n", | |
"df_pred = topn(model_l, pred_user, pred_items, n=n)\n", | |
"\n", | |
"# index to id\n", | |
"df_pred['user_id'] = df_pred['user_id_index'].map(index_to_user)\n", | |
"df_pred['item_id'] = df_pred['item_id_index'].map(index_to_item)\n", | |
"df_pred" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "zdvtSt1TfKUA", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# 全組み合わせのtopn Noneの時は全の組み合わせを取得\n", | |
"df_preds_topn = all(model_l,\n", | |
" X_train['user_id_index'],\n", | |
" X_train['item_id_index'], n=n)\n", | |
"df_preds_topn['user_id'] = df_preds_topn['user_id_index'].map(index_to_user)\n", | |
"df_preds_topn['item_id'] = df_preds_topn['item_id_index'].map(index_to_item)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "6MVhIQZ1fKUD", | |
"colab_type": "code", | |
"colab": {}, | |
"outputId": "ceb25a5b-e3f3-4d49-bcaf-73dd0f7c4979" | |
}, | |
"source": [ | |
"history.history['loss']" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[12.634256362915039,\n", | |
" 5.796047687530518,\n", | |
" 2.128021240234375,\n", | |
" 1.5114867687225342,\n", | |
" 1.3187335729599,\n", | |
" 1.2392791509628296,\n", | |
" 1.2014925479888916,\n", | |
" 1.1819627285003662,\n", | |
" 1.170408010482788,\n", | |
" 1.163235068321228,\n", | |
" 1.158171534538269,\n", | |
" 1.154823660850525,\n", | |
" 1.151853084564209,\n", | |
" 1.1492549180984497,\n", | |
" 1.1466904878616333,\n", | |
" 1.144012451171875,\n", | |
" 1.141165018081665,\n", | |
" 1.1381170749664307,\n", | |
" 1.134507179260254,\n", | |
" 1.1303595304489136]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 195 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "CIsdTxlKfKUI", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment