Skip to content

Instantly share code, notes, and snippets.

@uni-3
Last active August 2, 2020 04:55
Embed
What would you like to do?
matrix factorization with tensorflow
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
},
"colab": {
"name": "mf_keras.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/uni-3/559aeb67e86fb022b8480263b7608b1c/mf_keras.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "5wLON6eufKPP",
"colab_type": "code",
"colab": {}
},
"source": [
"import datetime\n",
"import numpy as np\n",
"import pandas as pd\n",
"import sklearn.model_selection\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"import os\n",
"import sys"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "G3HSls6VfKPu",
"colab_type": "text"
},
"source": [
"#### Set up the dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "v79IwzJXfKPw",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 526
},
"outputId": "ec80c737-9fc5-4a12-a199-2b6a0d31c663"
},
"source": [
"# download MovieLens 100k over https (skipped when the archive is already present)\n",
"!mkdir -p data\n",
"![ -f ./data/ml-100k.zip ] || curl -fsSL 'https://files.grouplens.org/datasets/movielens/ml-100k.zip' -o ./data/ml-100k.zip\n",
"!unzip -oq ./data/ml-100k.zip -d data"
],
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"text": [
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"\r 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\r100 4808k 100 4808k 0 0 12.0M 0 --:--:-- --:--:-- --:--:-- 12.0M\n",
"Archive: ./data/ml-100k.zip\n",
" inflating: data/ml-100k/allbut.pl \n",
" inflating: data/ml-100k/mku.sh \n",
" inflating: data/ml-100k/README \n",
" inflating: data/ml-100k/u.data \n",
" inflating: data/ml-100k/u.genre \n",
" inflating: data/ml-100k/u.info \n",
" inflating: data/ml-100k/u.item \n",
" inflating: data/ml-100k/u.occupation \n",
" inflating: data/ml-100k/u.user \n",
" inflating: data/ml-100k/u1.base \n",
" inflating: data/ml-100k/u1.test \n",
" inflating: data/ml-100k/u2.base \n",
" inflating: data/ml-100k/u2.test \n",
" inflating: data/ml-100k/u3.base \n",
" inflating: data/ml-100k/u3.test \n",
" inflating: data/ml-100k/u4.base \n",
" inflating: data/ml-100k/u4.test \n",
" inflating: data/ml-100k/u5.base \n",
" inflating: data/ml-100k/u5.test \n",
" inflating: data/ml-100k/ua.base \n",
" inflating: data/ml-100k/ua.test \n",
" inflating: data/ml-100k/ub.base \n",
" inflating: data/ml-100k/ub.test \n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "9R6dj_ctfP71",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 36
},
"outputId": "d01ef12c-bb98-459c-b6f9-a1e544f723a6"
},
"source": [
"# confirm the archive was downloaded and extracted\n",
"!ls data"
],
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"text": [
"ml-100k ml-100k.zip\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "dA0a3uwXfKP7",
"colab_type": "code",
"colab": {}
},
"source": [
"def id_to_index(df):\n",
"    \"\"\"Map raw user/item ids to contiguous 0-based indices and back.\n",
"\n",
"    Returns (user_to_index, item_to_index, index_to_user, index_to_item);\n",
"    each index_to_* dict is the inverse of the matching *_to_index dict.\n",
"    \"\"\"\n",
"    def _two_way(values):\n",
"        # enumerate a single set object so both directions share one ordering\n",
"        forward, backward = {}, {}\n",
"        for position, value in enumerate(set(values)):\n",
"            forward[value] = position\n",
"            backward[position] = value\n",
"        return forward, backward\n",
"\n",
"    user_to_index, index_to_user = _two_way(df['user_id'])\n",
"    item_to_index, index_to_item = _two_way(df['item_id'])\n",
"    return user_to_index, item_to_index, index_to_user, index_to_item\n",
"\n",
"def load_movielens(input_file=\"./data/ml-100k/u.data\") -> (pd.DataFrame, dict, dict, dict, dict):\n",
"    \"\"\"Read the MovieLens ratings file and attach 0-based index columns.\n",
"\n",
"    The file is tab separated with no header row: user_id, item_id, rating, timestamp.\n",
"    Returns the frame (with extra user_id_index / item_id_index columns) followed by\n",
"    the user_to_index, item_to_index, index_to_user and index_to_item mappings.\n",
"    \"\"\"\n",
"    column_types = {\n",
"        'user_id': np.int32,\n",
"        'item_id': np.int32,\n",
"        'rating': np.float32,\n",
"        'timestamp': np.int32,\n",
"    }\n",
"    ratings = pd.read_csv(input_file, sep='\\t', header=None,\n",
"                          names=list(column_types), dtype=column_types)\n",
"\n",
"    mappings = id_to_index(ratings)\n",
"    user_to_index, item_to_index = mappings[0], mappings[1]\n",
"    ratings['user_id_index'] = ratings['user_id'].map(user_to_index)\n",
"    ratings['item_id_index'] = ratings['item_id'].map(item_to_index)\n",
"    return (ratings, *mappings)\n"
],
"execution_count": 20,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "0SFz5z2bfKQD",
"colab_type": "code",
"colab": {}
},
"source": [
"def prep_data(df, col_names=['user_id', 'item_id'], target='rating', norm=True, n_nega_sample: int=None) -> (pd.DataFrame, int, int):\n",
"    \"\"\"Expand ratings to a dense user x item grid and keep positives (+ optional negatives).\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df: pd.DataFrame\n",
"        one row per observed interaction; must contain col_names and target\n",
"    col_names: list[string]\n",
"        names of the user column and the item column, in that order\n",
"    target: string\n",
"        name of the score column (also the score column of the result)\n",
"    norm: bool\n",
"        if truthy, every observed score is replaced by 1 (implicit feedback);\n",
"        the caller's DataFrame is not modified\n",
"    n_nega_sample: int\n",
"        per user, sample up to this many unobserved (score 0) pairs as negatives;\n",
"        if None, no negatives are added\n",
"\n",
"    Returns\n",
"    -------\n",
"    (pd.DataFrame, int, int)\n",
"        the [user, item, target] frame (positives first, then sampled negatives),\n",
"        the number of distinct users and the number of distinct items\n",
"    \"\"\"\n",
"    if norm:\n",
"        # work on a copy so the caller's scores are not overwritten\n",
"        df = df.assign(**{target: 1})\n",
"    n_user = df[col_names[0]].nunique()\n",
"    n_item = df[col_names[1]].nunique()\n",
"\n",
"    # dense user x item grid; unobserved pairs get score 0\n",
"    df_mat = pd.pivot_table(df, index=col_names[0],\n",
"                            columns=col_names[1], values=target, fill_value=0\n",
"                            ).stack().reset_index().rename(columns={0: target})\n",
"\n",
"    # positive samples\n",
"    df_posi = df_mat[df_mat[target] > 0].reset_index(drop=True)\n",
"    df_com = df_posi\n",
"\n",
"    if n_nega_sample is not None:\n",
"        df_nega = df_mat[df_mat[target] == 0].reset_index(drop=True)\n",
"\n",
"        # fixed random_state keeps the sampled negatives reproducible\n",
"        df_nega_samples = df_nega.groupby(col_names[0]) \\\n",
"            .apply(lambda x: x.sample(n=min(n_nega_sample, len(x)), random_state=14)) \\\n",
"            .reset_index(drop=True)\n",
"\n",
"        df_com = pd.concat([df_posi, df_nega_samples]).reset_index(drop=True)\n",
"\n",
"    return df_com, n_user, n_item\n"
],
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "EUvuur09fKQM",
"colab_type": "code",
"colab": {}
},
"source": [
"# load the ratings (adds user_id_index / item_id_index columns)\n",
"(df_raw, user_to_index, item_to_index,\n",
" index_to_user, index_to_item) = load_movielens()\n",
"\n",
"# index column names (kept for reference; prep_data below uses its default raw-id columns)\n",
"col_names = ['user_id_index', 'item_id_index']\n",
"target = 'rating'\n",
"\n",
"# observed ratings plus up to 5 sampled unobserved (rating 0) pairs per user; ratings not binarised\n",
"df, n_user, n_item = prep_data(df_raw, n_nega_sample=5, norm=None)"
],
"execution_count": 22,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "8GHGqf3lfKQY",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"outputId": "4699eab1-19bd-4a56-d26e-e9688fcfc6ff"
},
"source": [
"# raw ratings with the user_id_index / item_id_index columns added by load_movielens\n",
"df_raw.head()"
],
"execution_count": 23,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user_id</th>\n",
" <th>item_id</th>\n",
" <th>rating</th>\n",
" <th>timestamp</th>\n",
" <th>user_id_index</th>\n",
" <th>item_id_index</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>196</td>\n",
" <td>242</td>\n",
" <td>3.0</td>\n",
" <td>881250949</td>\n",
" <td>195</td>\n",
" <td>241</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>186</td>\n",
" <td>302</td>\n",
" <td>3.0</td>\n",
" <td>891717742</td>\n",
" <td>185</td>\n",
" <td>301</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>22</td>\n",
" <td>377</td>\n",
" <td>1.0</td>\n",
" <td>878887116</td>\n",
" <td>21</td>\n",
" <td>376</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>244</td>\n",
" <td>51</td>\n",
" <td>2.0</td>\n",
" <td>880606923</td>\n",
" <td>243</td>\n",
" <td>50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>166</td>\n",
" <td>346</td>\n",
" <td>1.0</td>\n",
" <td>886397596</td>\n",
" <td>165</td>\n",
" <td>345</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" user_id item_id rating timestamp user_id_index item_id_index\n",
"0 196 242 3.0 881250949 195 241\n",
"1 186 302 3.0 891717742 185 301\n",
"2 22 377 1.0 878887116 21 376\n",
"3 244 51 2.0 880606923 243 50\n",
"4 166 346 1.0 886397596 165 345"
]
},
"metadata": {
"tags": []
},
"execution_count": 23
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "u-G9piITfKQh",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"outputId": "e2c4c6b0-b9de-4be0-9f05-153715552961"
},
"source": [
"# training frame from prep_data: user/item ids and rating (0 marks a sampled negative)\n",
"df.head()"
],
"execution_count": 24,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user_id</th>\n",
" <th>item_id</th>\n",
" <th>rating</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1</td>\n",
" <td>5</td>\n",
" <td>3</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" user_id item_id rating\n",
"0 1 1 5\n",
"1 1 2 3\n",
"2 1 3 4\n",
"3 1 4 3\n",
"4 1 5 3"
]
},
"metadata": {
"tags": []
},
"execution_count": 24
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "fsINgd_RfKQo",
"colab_type": "code",
"colab": {}
},
"source": [
"# hold out 5% of the (user, item) pairs as a test set; fixed seed for reproducibility\n",
"X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(\n",
"    df[['user_id', 'item_id']],\n",
"    df[target],\n",
"    test_size=0.05,\n",
"    random_state=1,\n",
")"
],
"execution_count": 25,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ARWEtqnYfKQw",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 36
},
"outputId": "10c48f0b-c529-4f9e-8f45-87b8a3c34c84"
},
"source": [
"# (rows, feature columns) of the training split\n",
"X_train.shape"
],
"execution_count": 26,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(99479, 2)"
]
},
"metadata": {
"tags": []
},
"execution_count": 26
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "2le9n0mSfKQ7",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 36
},
"outputId": "651ec478-4b2e-431d-bca5-66023f2f69ba"
},
"source": [
"# (rows, feature columns) of the held-out split\n",
"X_test.shape"
],
"execution_count": 27,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(5236, 2)"
]
},
"metadata": {
"tags": []
},
"execution_count": 27
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "fXJYnroCfKRG",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 36
},
"outputId": "77ee3de1-3f36-4f05-f30b-04e74f7a3591"
},
"source": [
"# number of distinct users in the training split\n",
"X_train['user_id'].nunique()"
],
"execution_count": 28,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"943"
]
},
"metadata": {
"tags": []
},
"execution_count": 28
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "LKeZLlLVfKRN",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 36
},
"outputId": "7811ae25-e96e-46f3-9ad9-4a6ab01a0f3e"
},
"source": [
"# number of distinct users in the held-out split\n",
"X_test['user_id'].nunique()"
],
"execution_count": 29,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"877"
]
},
"metadata": {
"tags": []
},
"execution_count": 29
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "G7C06ajwfKRZ",
"colab_type": "code",
"colab": {}
},
"source": [
"# Matrix factorisation: the predicted rating is the dot product of a learned\n",
"# user vector and a learned item vector.\n",
"user_id_input = tf.keras.layers.Input(shape=[1], name='user')\n",
"item_id_input = tf.keras.layers.Input(shape=[1], name='item')\n",
"\n",
"embedding_size = 8\n",
"# input_dim is n + 1 so raw ids can index the table directly; presumably ids run\n",
"# 1..n (MovieLens), leaving row 0 unused -- TODO confirm if the id scheme changes\n",
"user_embedding = tf.keras.layers.Embedding(input_dim=n_user + 1, output_dim=embedding_size,\n",
"                                           name='user_embedding')(user_id_input)\n",
"item_embedding = tf.keras.layers.Embedding(input_dim=n_item + 1, output_dim=embedding_size,\n",
"                                           name='item_embedding')(item_id_input)\n",
"\n",
"# (batch, 1, embedding_size) -> (batch, embedding_size)\n",
"user_vecs = tf.keras.layers.Reshape([embedding_size])(user_embedding)\n",
"item_vecs = tf.keras.layers.Reshape([embedding_size])(item_embedding)\n",
"\n",
"# inner product of the two vectors = predicted score\n",
"y = tf.keras.layers.Dot(axes=1, normalize=False, name='score')([user_vecs, item_vecs])\n",
"\n",
"model = tf.keras.Model(inputs=[user_id_input, item_id_input], outputs=y)\n",
"model.compile(loss='mse', optimizer='adam')"
],
"execution_count": 30,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "h6tqn0vWfKRj",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 74
},
"outputId": "a4a60d7b-490b-4bf3-cc86-69bec60b2995"
},
"source": [
"# plotting dependencies for tf.keras.utils.plot_model, pinned to the versions last used here\n",
"%pip install -q pydot-ng==2.0.0 graphviz==0.10.1"
],
"execution_count": 31,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: pydot-ng in /usr/local/lib/python3.6/dist-packages (2.0.0)\n",
"Requirement already satisfied: graphviz in /usr/local/lib/python3.6/dist-packages (0.10.1)\n",
"Requirement already satisfied: pyparsing>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from pydot-ng) (2.4.7)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "pDfJmM3_fKRr",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 369
},
"outputId": "537daf61-f6be-4bd1-9f81-0088e8be180e"
},
"source": [
"# draw the model graph inline (layer names only, top-to-bottom)\n",
"tf.keras.utils.plot_model(model,\n",
"                          show_shapes=False,\n",
"                          show_layer_names=True,\n",
"                          rankdir='TB',\n",
"                          expand_nested=True,\n",
"                          dpi=96)\n"
],
"execution_count": 32,
"outputs": [
{
"output_type": "execute_result",
"data": {
"image/png": "\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {
"tags": []
},
"execution_count": 32
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "sviUZ9_UfKRx",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 451
},
"outputId": "398f354f-85c0-4334-889e-da48b23265fa"
},
"source": [
"# layer-by-layer summary with output shapes and parameter counts\n",
"model.summary()"
],
"execution_count": 33,
"outputs": [
{
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"user (InputLayer) [(None, 1)] 0 \n",
"__________________________________________________________________________________________________\n",
"item (InputLayer) [(None, 1)] 0 \n",
"__________________________________________________________________________________________________\n",
"user_embedding (Embedding) (None, 1, 8) 7552 user[0][0] \n",
"__________________________________________________________________________________________________\n",
"item_embedding (Embedding) (None, 1, 8) 13464 item[0][0] \n",
"__________________________________________________________________________________________________\n",
"reshape (Reshape) (None, 8) 0 user_embedding[0][0] \n",
"__________________________________________________________________________________________________\n",
"reshape_1 (Reshape) (None, 8) 0 item_embedding[0][0] \n",
"__________________________________________________________________________________________________\n",
"concat (Dot) (None, 1) 0 reshape[0][0] \n",
" reshape_1[0][0] \n",
"==================================================================================================\n",
"Total params: 21,016\n",
"Trainable params: 21,016\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "2kQ3kZPtfKR5",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "3dce9058-1bac-4195-ee2a-0f11f9055af8"
},
"source": [
"# keep the best model (lowest val_loss) and log metrics for TensorBoard\n",
"checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n",
"    './models_fm',\n",
"    monitor='val_loss',\n",
"    save_best_only=True,\n",
"    save_freq='epoch',  # the keyword is `save_freq`; `freq=` is not a ModelCheckpoint argument\n",
")\n",
"log_dir = os.path.join(\"./models_fm/logs\", datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n",
"tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n",
"\n",
"# the training history is persisted by the next cell\n",
"history = model.fit(\n",
"    x=[X_train[\"user_id\"], X_train[\"item_id\"]],\n",
"    y=y_train,\n",
"    batch_size=128,\n",
"    epochs=20,\n",
"    validation_split=0.1,\n",
"    callbacks=[checkpoint_callback, tensorboard_callback],\n",
"    shuffle=True,\n",
")"
],
"execution_count": 34,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"699/700 [============================>.] - ETA: 0s - loss: 12.4237WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"If using Keras pass *_constraint arguments to layers.\n",
"INFO:tensorflow:Assets written to: ./models_fm/assets\n",
"700/700 [==============================] - 2s 3ms/step - loss: 12.4219 - val_loss: 9.6668\n",
"Epoch 2/20\n",
"681/700 [============================>.] - ETA: 0s - loss: 5.5264INFO:tensorflow:Assets written to: ./models_fm/assets\n",
"700/700 [==============================] - 2s 3ms/step - loss: 5.4549 - val_loss: 2.7801\n",
"Epoch 3/20\n",
"682/700 [============================>.] - ETA: 0s - loss: 2.0878INFO:tensorflow:Assets written to: ./models_fm/assets\n",
"700/700 [==============================] - 2s 3ms/step - loss: 2.0768 - val_loss: 1.6733\n",
"Epoch 4/20\n",
"697/700 [============================>.] - ETA: 0s - loss: 1.4977INFO:tensorflow:Assets written to: ./models_fm/assets\n",
"700/700 [==============================] - 2s 2ms/step - loss: 1.4976 - val_loss: 1.3814\n",
"Epoch 5/20\n",
"699/700 [============================>.] - ETA: 0s - loss: 1.3131INFO:tensorflow:Assets written to: ./models_fm/assets\n",
"700/700 [==============================] - 2s 3ms/step - loss: 1.3134 - val_loss: 1.2714\n",
"Epoch 6/20\n",
"698/700 [============================>.] - ETA: 0s - loss: 1.2366INFO:tensorflow:Assets written to: ./models_fm/assets\n",
"700/700 [==============================] - 2s 2ms/step - loss: 1.2371 - val_loss: 1.2249\n",
"Epoch 7/20\n",
"692/700 [============================>.] - ETA: 0s - loss: 1.2012INFO:tensorflow:Assets written to: ./models_fm/assets\n",
"700/700 [==============================] - 2s 2ms/step - loss: 1.2008 - val_loss: 1.2033\n",
"Epoch 8/20\n",
"675/700 [===========================>..] - ETA: 0s - loss: 1.1818INFO:tensorflow:Assets written to: ./models_fm/assets\n",
"700/700 [==============================] - 2s 2ms/step - loss: 1.1818 - val_loss: 1.1919\n",
"Epoch 9/20\n",
"675/700 [===========================>..] - ETA: 0s - loss: 1.1696INFO:tensorflow:Assets written to: ./models_fm/assets\n",
"700/700 [==============================] - 2s 3ms/step - loss: 1.1706 - val_loss: 1.1882\n",
"Epoch 10/20\n",
"678/700 [============================>.] - ETA: 0s - loss: 1.1647INFO:tensorflow:Assets written to: ./models_fm/assets\n",
"700/700 [==============================] - 2s 3ms/step - loss: 1.1639 - val_loss: 1.1855\n",
"Epoch 11/20\n",
"700/700 [==============================] - 1s 2ms/step - loss: 1.1593 - val_loss: 1.1859\n",
"Epoch 12/20\n",
"695/700 [============================>.] - ETA: 0s - loss: 1.1563INFO:tensorflow:Assets written to: ./models_fm/assets\n",
"700/700 [==============================] - 2s 3ms/step - loss: 1.1562 - val_loss: 1.1836\n",
"Epoch 13/20\n",
"693/700 [============================>.] - ETA: 0s - loss: 1.1534INFO:tensorflow:Assets written to: ./models_fm/assets\n",
"700/700 [==============================] - 2s 3ms/step - loss: 1.1537 - val_loss: 1.1825\n",
"Epoch 14/20\n",
"700/700 [==============================] - 1s 2ms/step - loss: 1.1516 - val_loss: 1.1844\n",
"Epoch 15/20\n",
"675/700 [===========================>..] - ETA: 0s - loss: 1.1484INFO:tensorflow:Assets written to: ./models_fm/assets\n",
"700/700 [==============================] - 2s 3ms/step - loss: 1.1494 - val_loss: 1.1810\n",
"Epoch 16/20\n",
"700/700 [==============================] - 1s 2ms/step - loss: 1.1475 - val_loss: 1.1822\n",
"Epoch 17/20\n",
"700/700 [==============================] - 1s 2ms/step - loss: 1.1457 - val_loss: 1.1812\n",
"Epoch 18/20\n",
"700/700 [==============================] - 1s 2ms/step - loss: 1.1438 - val_loss: 1.1815\n",
"Epoch 19/20\n",
"700/700 [==============================] - 1s 2ms/step - loss: 1.1416 - val_loss: 1.1813\n",
"Epoch 20/20\n",
"681/700 [============================>.] - ETA: 0s - loss: 1.1383INFO:tensorflow:Assets written to: ./models_fm/assets\n",
"700/700 [==============================] - 2s 2ms/step - loss: 1.1389 - val_loss: 1.1804\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "TKC1qsN-fKSA",
"colab_type": "code",
"colab": {}
},
"source": [
"import pickle\n",
"\n",
"# persist the per-epoch metrics (loss, val_loss) recorded by model.fit\n",
"with open(os.path.join(\"models_fm\", \"history.pkl\"), mode=\"wb\") as f:\n",
"    pickle.dump(history.history, f)"
],
"execution_count": 35,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "1ZsBmNlOfKSH",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 36
},
"outputId": "f563fad7-bb19-4554-e39e-4a4abd29236f"
},
"source": [
"# item embedding table: (n_item + 1, embedding_size); looked up by name, not list position\n",
"model.get_layer('item_embedding').get_weights()[0].shape"
],
"execution_count": 36,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(1683, 8)"
]
},
"metadata": {
"tags": []
},
"execution_count": 36
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "DWVtifB6fKSM",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 36
},
"outputId": "ddadb22e-c25b-48c1-cc26-18c9fa3b381d"
},
"source": [
"# user embedding table: (n_user + 1, embedding_size); looked up by name, not list position\n",
"model.get_layer('user_embedding').get_weights()[0].shape"
],
"execution_count": 37,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(944, 8)"
]
},
"metadata": {
"tags": []
},
"execution_count": 37
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "w2IDZ5FEfKSW",
"colab_type": "code",
"colab": {}
},
"source": [
"# export the learned embedding tables (CSV row index = id fed to the model)\n",
"pd.DataFrame(model.get_layer('user_embedding').get_weights()[0]).to_csv('./user_emb_8.csv')\n",
"pd.DataFrame(model.get_layer('item_embedding').get_weights()[0]).to_csv('./item_emb_8.csv')\n"
],
"execution_count": 38,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "eoL-CMGWfKSg",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"outputId": "40f783ea-09c6-4289-e803-827f48b68168"
},
"source": [
"# predicted score for every (user row, item row) pair of the two embedding tables\n",
"user_weights = model.get_layer('user_embedding').get_weights()[0]\n",
"item_weights = model.get_layer('item_embedding').get_weights()[0]\n",
"df_scores = pd.DataFrame(user_weights @ item_weights.T).stack().reset_index()\n",
"df_scores.columns = ['user_id', 'item_id', 'rating']\n",
"# NOTE(review): row/column 0 is presumably never looked up (ids start at 1), so its scores are untrained\n",
"df_scores.to_csv('./scores.csv', index=False)\n",
"df_scores.head()\n"
],
"execution_count": 39,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user_id</th>\n",
" <th>item_id</th>\n",
" <th>rating</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.001375</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>-0.060442</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>-0.048946</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>-0.048313</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>-0.050197</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" user_id item_id rating\n",
"0 0 0 0.001375\n",
"1 0 1 -0.060442\n",
"2 0 2 -0.048946\n",
"3 0 3 -0.048313\n",
"4 0 4 -0.050197"
]
},
"metadata": {
"tags": []
},
"execution_count": 39
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "TSybDjVyfKSm",
"colab_type": "code",
"colab": {}
},
"source": [
"# Optional: 2-d UMAP projection of the learned item embeddings (needs the `umap-learn` package).\n",
"# pandas and matplotlib are already imported at the top of the notebook.\n",
"import umap\n",
"\n",
"# TODO: uncomment to project the 8-d item embeddings exported above and save the result\n",
"# df_item_emb = pd.read_csv('./item_emb_8.csv', index_col=0)\n",
"# embedding = umap.UMAP().fit_transform(df_item_emb.values)\n",
"# pd.DataFrame(embedding).to_csv('item_emb_umap.csv')"
],
"execution_count": 40,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "fCTc45hqfKSs",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 36
},
"outputId": "fec48ccf-bb58-454e-80a5-cdbf9161354d"
},
"source": [
"# shape of the full score matrix: (user rows, item rows); no DataFrame needed just for the shape\n",
"(model.get_layer('user_embedding').get_weights()[0]\n",
" @ model.get_layer('item_embedding').get_weights()[0].T).shape"
],
"execution_count": 41,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(944, 1683)"
]
},
"metadata": {
"tags": []
},
"execution_count": 41
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q3bJS9QofKSx",
"colab_type": "text"
},
"source": [
"#### TensorBoard in notebooks: https://www.tensorflow.org/tensorboard/tensorboard_in_notebooks"
]
},
{
"cell_type": "code",
"metadata": {
"id": "vABcwTIifKSy",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 41,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "KXE5j0eFfKS4",
"colab_type": "code",
"colab": {}
},
"source": [
"# Stop a stale TensorBoard server before relaunching it (disabled by default).\n",
"# The PID changes every session, so a hard-coded `!kill 437` could signal an\n",
"# unrelated process on a fresh kernel; look the PID up with `notebook.list()` first.\n",
"# !kill <pid>"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "X_vFtJ3BfKS8",
"colab_type": "code",
"colab": {},
"outputId": "352afdaf-e71b-4bc5-8653-a528205e8142"
},
"source": [
"# list TensorBoard instances already running (e.g. to find a stale one to stop)\n",
"from tensorboard import notebook\n",
"notebook.list() # View open TensorBoard instances"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"No known TensorBoard instances running.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "A7gnzyCCfKTB",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "shMtu5fefKTG",
"colab_type": "code",
"colab": {}
},
"source": [
"# load the TensorBoard extension, which provides the %tensorboard magic used below\n",
"%load_ext tensorboard"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Wrn5VyNEfKTL",
"colab_type": "code",
"colab": {},
"outputId": "aa89e246-ccf7-4318-d239-44f6280bb1d6"
},
"source": [
"# %reload_ext makes this cell work even if the extension is already loaded.\n",
"# NOTE(review): --host 0.0.0.0 makes TensorBoard listen on every network interface,\n",
"# presumably so it is reachable from outside a container; drop it to stay on localhost.\n",
"%reload_ext tensorboard\n",
"%tensorboard --logdir ./models_fm/logs --host 0.0.0.0"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <iframe id=\"tensorboard-frame-8722f0df074f4d28\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
" </iframe>\n",
" <script>\n",
" (function() {\n",
" const frame = document.getElementById(\"tensorboard-frame-8722f0df074f4d28\");\n",
" const url = new URL(\"/\", window.location);\n",
" url.port = 6006;\n",
" frame.src = url;\n",
" })();\n",
" </script>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "anW_lEhufKTO",
"colab_type": "code",
"colab": {}
},
"source": [
"# predictions of the in-memory (final-epoch) model for each held-out pair\n",
"y_pred = model.predict(\n",
"    x=[X_test[\"user_id\"], X_test[\"item_id\"]],\n",
")\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "8eJWWrGSfKTU",
"colab_type": "code",
"colab": {},
"outputId": "8f3ed7a7-9686-4b9b-a29f-cb8803b09f97"
},
"source": [
"# one prediction per held-out pair.\n",
"# NOTE(review): the saved output (20000, 1) does not match X_test's 5236 rows -- likely stale; re-run to refresh\n",
"y_pred.shape"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(20000, 1)"
]
},
"metadata": {
"tags": []
},
"execution_count": 27
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "qNN3BO0IfKTa",
"colab_type": "code",
"colab": {},
"outputId": "76b03505-0c3b-4a16-95a0-369ad3257b46"
},
"source": [
"# print the first held-out pair: user/item ids, predicted score and true rating\n",
"print(f'''user {X_test[\"user_id\"].values[0]}のitem {X_test[\"item_id\"].values[0]}に対する\\n\n",
"予測値: {y_pred[0][0]}\\n\n",
"真の値: {y_test.values[0]}\n",
"''')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"user 265のitem 284に対する\n",
"\n",
"予測値: 3.259321689605713\n",
"\n",
"真の値: 4\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "EIDZr1JtfKTf",
"colab_type": "code",
"colab": {}
},
"source": [
"from sklearn.metrics import mean_squared_error\n",
"\n",
"# root-mean-squared error of the in-memory model on the held-out pairs\n",
"np.sqrt(mean_squared_error(y_true=y_test, y_pred=y_pred))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "sSehlno-fKTm",
"colab_type": "code",
"colab": {},
"outputId": "c0e12881-2498-41cb-e625-2881b2c3d1f3"
},
"source": [
"# Recreate the saved model exactly, including its weights and optimizer state.\n",
"# With save_best_only=True the checkpoint holds the epoch with the lowest val_loss.\n",
"# NOTE(review): the saved summary below (30-d embeddings, 4.8M params) differs from the 8-d model built above -- likely stale.\n",
"model_s = './models_fm'\n",
"model_l = tf.keras.models.load_model(model_s)\n",
"\n",
"model_l.summary()\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"user (InputLayer) [(None, 1)] 0 \n",
"__________________________________________________________________________________________________\n",
"item (InputLayer) [(None, 1)] 0 \n",
"__________________________________________________________________________________________________\n",
"user_embedding (Embedding) (None, 1, 30) 2400000 user[0][0] \n",
"__________________________________________________________________________________________________\n",
"item_embedding (Embedding) (None, 1, 30) 2400000 item[0][0] \n",
"__________________________________________________________________________________________________\n",
"reshape (Reshape) (None, 30) 0 user_embedding[0][0] \n",
"__________________________________________________________________________________________________\n",
"reshape_1 (Reshape) (None, 30) 0 item_embedding[0][0] \n",
"__________________________________________________________________________________________________\n",
"dot (Dot) (None, 1) 0 reshape[0][0] \n",
" reshape_1[0][0] \n",
"==================================================================================================\n",
"Total params: 4,800,000\n",
"Trainable params: 4,800,000\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "dZ_qj4ogfKTq",
"colab_type": "code",
"colab": {}
},
"source": [
"# Score every row of X_test; inputs are given as [user ids, item ids]\n",
"y_pred = model_l.predict([X_test[col] for col in (\"user_id\", \"item_id\")])\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "JwUAGj8SfKTt",
"colab_type": "code",
"colab": {},
"outputId": "43a2fd06-3a40-4373-cf2c-d8d41761e86e"
},
"source": [
"from sklearn.metrics import mean_squared_error\n",
"\n",
"# Root mean squared error of the predictions against y_test\n",
"np.sqrt(mean_squared_error(y_true=y_test, y_pred=y_pred))"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.93278754"
]
},
"metadata": {
"tags": []
},
"execution_count": 28
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fbltKPwrfKTy",
"colab_type": "text"
},
"source": [
"### Predict top-N items per user"
]
},
{
"cell_type": "code",
"metadata": {
"id": "97LJTYJ3fKTz",
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"\n",
"import itertools\n",
"\n",
"def all(model: tf.keras.Model, users: list, items: list, n: int=None):\n",
"    \"\"\"\n",
"    Score every (user, item) combination with the model.\n",
"\n",
"    NOTE: this name shadows the builtin ``all`` in the notebook namespace;\n",
"    it is kept because later cells call the function by this name.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    model: tf.keras.Model\n",
"        model whose inputs are ``[user, item]``, in that order\n",
"    users: list\n",
"        user indices; duplicates are ignored\n",
"    items: list\n",
"        item indices; duplicates are ignored\n",
"    n: int, default None\n",
"        if not None, keep only the n highest-scoring items for each user\n",
"\n",
"    Returns\n",
"    -------\n",
"    df: pd.DataFrame\n",
"        columns ``user_id_index``, ``item_id_index``, ``score``\n",
"    \"\"\"\n",
"    unique_users = np.array(list(set(users)))\n",
"    unique_items = np.array(list(set(items)))\n",
"    # every (user, item) pair in itertools.product order: users outer, items inner\n",
"    ul = np.repeat(unique_users, len(unique_items))\n",
"    il = np.tile(unique_items, len(unique_users))\n",
"\n",
"    if len(ul) == 0:\n",
"        # nothing to score; do not call predict on empty inputs\n",
"        return pd.DataFrame(columns=[\"user_id_index\", \"item_id_index\", \"score\"])\n",
"\n",
"    # the model expects [user, item]; passing [item, user] looks each id up in\n",
"    # the wrong embedding table and silently yields meaningless scores\n",
"    y_pred = model.predict(x=[ul, il]).flatten()\n",
"\n",
"    df = pd.DataFrame(\n",
"        {\n",
"            \"user_id_index\": ul,\n",
"            \"item_id_index\": il,\n",
"            \"score\": y_pred,\n",
"        }\n",
"    )\n",
"\n",
"    # keep the top-n rows of each user\n",
"    if n is not None:\n",
"        return df.groupby([\"user_id_index\"]).apply(lambda row: row.nlargest(n, 'score')).reset_index(drop=True)\n",
"\n",
"    return df\n",
"\n",
"\n",
"def topn(model: tf.keras.Model, user_id: int, items: list, n=10):\n",
"    \"\"\"\n",
"    Return the n highest-scoring items for a single user.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    model: tf.keras.Model\n",
"        model whose inputs are ``[user, item]``, in that order\n",
"    user_id: int, str\n",
"        user index fed to the model\n",
"    items: list\n",
"        candidate item indices\n",
"    n: int, default 10\n",
"        number of items to return; fewer rows come back when ``items``\n",
"        has fewer than n entries\n",
"\n",
"    Returns\n",
"    -------\n",
"    df: pd.DataFrame\n",
"        columns ``user_id_index``, ``item_id_index``, ``score``,\n",
"        ordered by descending score\n",
"    \"\"\"\n",
"    item_arr = np.array(items)\n",
"    if len(item_arr) == 0:\n",
"        # nothing to score; do not call predict on empty inputs\n",
"        return pd.DataFrame(columns=[\"user_id_index\", \"item_id_index\", \"score\"])\n",
"\n",
"    user_arr = np.array([user_id] * len(item_arr))\n",
"    # the model expects [user, item]; the same user is paired with every item\n",
"    y_pred = model.predict(x=[user_arr, item_arr])\n",
"\n",
"    # one score per item (the model output has a single column)\n",
"    y = np.asarray(y_pred).reshape(-1)\n",
"    # positions of the n largest scores, highest first\n",
"    idx = np.argsort(y)[::-1][:n]\n",
"\n",
"    df = pd.DataFrame(\n",
"        {\n",
"            # len(idx), not n: idx is shorter when there are fewer than n items\n",
"            \"user_id_index\": [user_id] * len(idx),\n",
"            \"item_id_index\": item_arr[idx],\n",
"            \"score\": y[idx],\n",
"        }\n",
"    )\n",
"\n",
"    return df"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "z5p5C-j8fKT7",
"colab_type": "code",
"colab": {}
},
"source": [
"# Top-n items for one user: map the raw user id to the model's user index first\n",
"n = 5\n",
"pred_user = user_to_index[user_id]\n",
"pred_items = list(set(X_test['item_id_index']))\n",
"\n",
"# Attach the original ids next to the index columns\n",
"df_pred = topn(model_l, pred_user, pred_items, n=n).assign(\n",
"    user_id=lambda d: d['user_id_index'].map(index_to_user),\n",
"    item_id=lambda d: d['item_id_index'].map(index_to_item),\n",
")\n",
"df_pred"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "zdvtSt1TfKUA",
"colab_type": "code",
"colab": {}
},
"source": [
"# Top-n items for every user; with n=None every (user, item) score is kept\n",
"df_preds_topn = all(\n",
"    model_l,\n",
"    X_train['user_id_index'],\n",
"    X_train['item_id_index'],\n",
"    n=n,\n",
").assign(\n",
"    user_id=lambda d: d['user_id_index'].map(index_to_user),\n",
"    item_id=lambda d: d['item_id_index'].map(index_to_item),\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "6MVhIQZ1fKUD",
"colab_type": "code",
"colab": {},
"outputId": "ceb25a5b-e3f3-4d49-bcaf-73dd0f7c4979"
},
"source": [
"# Training loss per epoch (history is presumably the object returned by fit() earlier)\n",
"history.history[\"loss\"]"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[12.634256362915039,\n",
" 5.796047687530518,\n",
" 2.128021240234375,\n",
" 1.5114867687225342,\n",
" 1.3187335729599,\n",
" 1.2392791509628296,\n",
" 1.2014925479888916,\n",
" 1.1819627285003662,\n",
" 1.170408010482788,\n",
" 1.163235068321228,\n",
" 1.158171534538269,\n",
" 1.154823660850525,\n",
" 1.151853084564209,\n",
" 1.1492549180984497,\n",
" 1.1466904878616333,\n",
" 1.144012451171875,\n",
" 1.141165018081665,\n",
" 1.1381170749664307,\n",
" 1.134507179260254,\n",
" 1.1303595304489136]"
]
},
"metadata": {
"tags": []
},
"execution_count": 195
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "CIsdTxlKfKUI",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment