Last active
June 22, 2021 10:09
-
-
Save napjon/c47d62c08619d16278ea8eba4a5c1c69 to your computer and use it in GitHub Desktop.
Add Feature FastAI Use Collab for topN recommendation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "collab-topN-recommend.ipynb", | |
"provenance": [], | |
"collapsed_sections": [] | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "muMLDfmfkoNZ", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Overview\n", | |
"Recommendation engines sometimes need to make predictions in real time: when a user visits or buys a particular product, we want to immediately recommend other products we think they will be interested in." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "uTuqH3UClDaJ", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"from fastai.collab import *\n", | |
"from fastai.tabular import *" | |
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WwB3kJEFlMpS", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Column names of the MovieLens ratings table. Note: `title` is not a column of ratings.csv\n", | |
"# (presumably it belongs to movies.csv) and is not used in this notebook -- TODO use or remove.\n", | |
"user, item, title = \"userId\", \"movieId\", \"title\"\n", | |
"# Fetch and extract fastai's MovieLens sample; `path` is the extracted folder holding ratings.csv.\n", | |
"path = untar_data(URLs.ML_SAMPLE)" | |
], | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "bv3kku5OlSED", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 204 | |
}, | |
"outputId": "6b1fed8d-e086-4464-beca-93473462d27a" | |
}, | |
"source": [ | |
"# Load the ratings table (one row per user/movie rating) and preview the first rows.\n", | |
"ratings = pd.read_csv(path/\"ratings.csv\")\n", | |
"ratings.head()" | |
], | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>userId</th>\n", | |
" <th>movieId</th>\n", | |
" <th>rating</th>\n", | |
" <th>timestamp</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>73</td>\n", | |
" <td>1097</td>\n", | |
" <td>4.0</td>\n", | |
" <td>1255504951</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>561</td>\n", | |
" <td>924</td>\n", | |
" <td>3.5</td>\n", | |
" <td>1172695223</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>157</td>\n", | |
" <td>260</td>\n", | |
" <td>3.5</td>\n", | |
" <td>1291598691</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>358</td>\n", | |
" <td>1210</td>\n", | |
" <td>5.0</td>\n", | |
" <td>957481884</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>130</td>\n", | |
" <td>316</td>\n", | |
" <td>2.0</td>\n", | |
" <td>1138999234</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" userId movieId rating timestamp\n", | |
"0 73 1097 4.0 1255504951\n", | |
"1 561 924 3.5 1172695223\n", | |
"2 157 260 3.5 1291598691\n", | |
"3 358 1210 5.0 957481884\n", | |
"4 130 316 2.0 1138999234" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 3 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dwsa2wUhlSgG", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 142 | |
}, | |
"outputId": "f43c347b-781b-4ca5-aa33-a8e69729ae25" | |
}, | |
"source": [ | |
"data = CollabDataBunch.from_df(ratings, seed=42)\n", | |
"learn = collab_learner(data, n_factors=50, y_range = [0, 5.5])\n", | |
"learn.fit_one_cycle(3, 5e-3)" | |
], | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>1.605562</td>\n", | |
" <td>0.914348</td>\n", | |
" <td>00:00</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>0.859106</td>\n", | |
" <td>0.680887</td>\n", | |
" <td>00:00</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>2</td>\n", | |
" <td>0.672545</td>\n", | |
" <td>0.677157</td>\n", | |
" <td>00:00</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "kiyHrKLalhnO", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"With the functions below we can predict, in real time if needed, the top 10 products for a particular user or, conversely, the top 10 users for a particular product, which can serve as a campaign whitelist for that product." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pUCpj-mFlvTB", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def predict_topN(learn, user=None, item=None, N=10, not_used=True, for_user=True):\n", | |
"\n", | |
" xd = learn.data.train_ds.x.classes\n", | |
" userid, itemid = xd.keys()\n", | |
" users, items = xd.values()\n", | |
"\n", | |
" if not_used:\n", | |
" dfi = learn.data.train_ds.x.inner_df\n", | |
" if for_user:\n", | |
" f_items = dfi.loc[dfi[userid] == int(user), itemid].unique()\n", | |
" items = dfi.loc[~dfi[itemid].isin(f_items), itemid].astype(str).unique()\n", | |
"\n", | |
" else:\n", | |
" f_users = dfi.loc[dfi[itemid] == int(item), userid].unique()\n", | |
" users = dfi.loc[~dfi[userid].isin(f_users), userid].astype(str).unique()\n", | |
"\n", | |
" if for_user:\n", | |
" users = [str(user)]\n", | |
" output = items\n", | |
" else:\n", | |
" items = [str(item)]\n", | |
" output = users\n", | |
" \n", | |
" \n", | |
" predictions = learn.model.forward(learn.get_idx(users, is_item=False),\n", | |
" learn.get_idx(items))\n", | |
" \n", | |
" return sorted(zip(output, predictions), key=lambda x:x[1], reverse=True)[:N]\n", | |
"\n", | |
"\n", | |
"def predict_topN_item_for_user(learn, user, N=10, not_used=True):\n", | |
" return predict_topN(learn, user=user, N=N, not_used=not_used, for_user=True)\n", | |
"\n", | |
"def predict_topN_user_for_item(learn, item, N=10, not_used=True):\n", | |
" return predict_topN(learn, item=item, N=N, not_used=not_used, for_user=False)" | |
], | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "HXeIw6xAlxcQ", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 187 | |
}, | |
"outputId": "2cac3a22-d820-4f84-abe8-ba97120baa34" | |
}, | |
"source": [ | |
"# Top-10 (the default N) movies that user 561 has not rated in the training data.\n", | |
"predict_topN_item_for_user(learn, 561)" | |
], | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[('1221', tensor(4.5355)),\n", | |
" ('50', tensor(4.4703)),\n", | |
" ('608', tensor(4.4115)),\n", | |
" ('858', tensor(4.4111)),\n", | |
" ('318', tensor(4.3977)),\n", | |
" ('778', tensor(4.3464)),\n", | |
" ('527', tensor(4.3296)),\n", | |
" ('1193', tensor(4.2798)),\n", | |
" ('47', tensor(4.2782)),\n", | |
" ('1210', tensor(4.2773))]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 6 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "yneVsK7Bm51J", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Set `not_used=False` to also include products the user has already rated; by default they are excluded from the recommendations." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "NlHdUfoUmg60", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 187 | |
}, | |
"outputId": "2b2682f2-5004-4515-adb2-c11c81a6040b" | |
}, | |
"source": [ | |
"# Same query, but movies the user has already rated are kept in the ranking.\n", | |
"predict_topN_item_for_user(learn, 561, not_used=False)" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[('1221', tensor(4.5355)),\n", | |
" ('296', tensor(4.5167)),\n", | |
" ('50', tensor(4.4703)),\n", | |
" ('1196', tensor(4.4272)),\n", | |
" ('4973', tensor(4.4204)),\n", | |
" ('608', tensor(4.4115)),\n", | |
" ('858', tensor(4.4111)),\n", | |
" ('318', tensor(4.3977)),\n", | |
" ('1136', tensor(4.3617)),\n", | |
" ('1198', tensor(4.3557))]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 7 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "rrtLY1UonZnc", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"As explained earlier, we can also find the users most likely to be interested in a particular product. This is typically used to build a user whitelist for a campaign promoting that product." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "1Lokb10yl6Xw", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 187 | |
}, | |
"outputId": "da321df5-bb6f-4da4-8865-9e70c8bc3260" | |
}, | |
"source": [ | |
"# Top-10 users who have not rated movie 1089 in the training data (campaign candidates).\n", | |
"predict_topN_user_for_item(learn, \"1089\")" | |
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[('242', tensor(5.0561)),\n", | |
" ('577', tensor(4.9409)),\n", | |
" ('128', tensor(4.7861)),\n", | |
" ('95', tensor(4.7580)),\n", | |
" ('431', tensor(4.7357)),\n", | |
" ('358', tensor(4.6373)),\n", | |
" ('268', tensor(4.6063)),\n", | |
" ('292', tensor(4.6043)),\n", | |
" ('598', tensor(4.5964)),\n", | |
" ('247', tensor(4.5890))]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 8 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "tKsAJXeEl9kA", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 187 | |
}, | |
"outputId": "259e3b98-f4e0-47e0-c87e-6b381d75c373" | |
}, | |
"source": [ | |
"# Same query, but users who have already rated movie 1089 are kept in the ranking.\n", | |
"predict_topN_user_for_item(learn, \"1089\", not_used=False)" | |
], | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[('242', tensor(5.0561)),\n", | |
" ('577', tensor(4.9409)),\n", | |
" ('654', tensor(4.9298)),\n", | |
" ('128', tensor(4.7861)),\n", | |
" ('95', tensor(4.7580)),\n", | |
" ('431', tensor(4.7357)),\n", | |
" ('480', tensor(4.7314)),\n", | |
" ('544', tensor(4.6986)),\n", | |
" ('358', tensor(4.6373)),\n", | |
" ('30', tensor(4.6316))]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 9 | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Updated to fix bug for use case
predict_topN_user_for_item