Created
August 8, 2023 05:43
-
-
Save pxpc2/e6749d6b86c6eae19b3f560263fdfb07 to your computer and use it in GitHub Desktop.
pedr-wine-rec.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyMkfkdLJ0GEJ/qMgTH0NYHe" | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"id": "3mgsncFvY9ks" | |
}, | |
"outputs": [], | |
"source": [ | |
"! [ -e /content ] && pip install -Uqq fastbook\n", | |
"import fastbook\n", | |
"fastbook.setup_book()\n", | |
"\n", | |
"from fastbook import *\n", | |
"from fastai.collab import *\n", | |
"from fastai.tabular.all import *" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"wines = pd.read_csv('/content/XWines_Slim_1K_wines.csv',\n", | |
" usecols=['WineID', 'WineName', 'Type', 'Grapes',\n", | |
" 'Body', 'Harmonize', 'Code'],\n", | |
" delimiter=',', header=None, skiprows=1,\n", | |
" names=['WineID', 'WineName', 'Type', 'Elaborate', 'Grapes',\n", | |
" 'Harmonize', 'ABV', 'Body', 'Acidity', 'Code',\n", | |
" 'Country', 'RegionID', 'RegionName', 'WineryID',\n", | |
" 'WineryName','Website', 'Vintages'],\n", | |
" dtype={'WineID': int, 'WineName': str, 'Type': str,\n", | |
" 'Grapes': str, 'Harmonize': str, 'Body': str,\n", | |
" 'Code': str})\n", | |
"\n", | |
"\n", | |
"\n", | |
"ratings = pd.read_csv('/content/XWines_Slim_150K_ratings.csv', usecols=['UserID', 'WineID', 'Rating'], delimiter=',', header=None, skiprows=1,\n", | |
" names=['RatingID', 'UserID', 'WineID', 'Vintage', 'Rating', 'Date'],\n", | |
" dtype={'UserID': int, 'WineID': int, 'Rating': float})" | |
], | |
"metadata": { | |
"id": "2PeMkTLzazqw" | |
}, | |
"execution_count": 84, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"ratings = ratings.merge(wines)\n", | |
"ratings.head()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 206 | |
}, | |
"id": "3KL0e4GWZ2ct", | |
"outputId": "3bc692a3-4d3e-4450-bc57-cd0071d0bbd5" | |
}, | |
"execution_count": 85, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" UserID WineID Rating WineName Type \\\n", | |
"0 1356810 103471 4.5 Presidential Colheita Port Dessert/Port \n", | |
"1 1212696 103471 4.5 Presidential Colheita Port Dessert/Port \n", | |
"2 1653207 103471 4.0 Presidential Colheita Port Dessert/Port \n", | |
"3 1314309 103471 4.0 Presidential Colheita Port Dessert/Port \n", | |
"4 1201703 103471 4.0 Presidential Colheita Port Dessert/Port \n", | |
"\n", | |
" Grapes \\\n", | |
"0 ['Touriga Nacional', 'Tinta Roriz'] \n", | |
"1 ['Touriga Nacional', 'Tinta Roriz'] \n", | |
"2 ['Touriga Nacional', 'Tinta Roriz'] \n", | |
"3 ['Touriga Nacional', 'Tinta Roriz'] \n", | |
"4 ['Touriga Nacional', 'Tinta Roriz'] \n", | |
"\n", | |
" Harmonize Body Code \n", | |
"0 ['Beef', 'Maturated Cheese', 'Hard Cheese'] Very full-bodied PT \n", | |
"1 ['Beef', 'Maturated Cheese', 'Hard Cheese'] Very full-bodied PT \n", | |
"2 ['Beef', 'Maturated Cheese', 'Hard Cheese'] Very full-bodied PT \n", | |
"3 ['Beef', 'Maturated Cheese', 'Hard Cheese'] Very full-bodied PT \n", | |
"4 ['Beef', 'Maturated Cheese', 'Hard Cheese'] Very full-bodied PT " | |
], | |
"text/html": [ | |
"\n", | |
"\n", | |
" <div id=\"df-e1519468-01b4-4ee5-8ab9-fd43d42fef20\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>UserID</th>\n", | |
" <th>WineID</th>\n", | |
" <th>Rating</th>\n", | |
" <th>WineName</th>\n", | |
" <th>Type</th>\n", | |
" <th>Grapes</th>\n", | |
" <th>Harmonize</th>\n", | |
" <th>Body</th>\n", | |
" <th>Code</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1356810</td>\n", | |
" <td>103471</td>\n", | |
" <td>4.5</td>\n", | |
" <td>Presidential Colheita Port</td>\n", | |
" <td>Dessert/Port</td>\n", | |
" <td>['Touriga Nacional', 'Tinta Roriz']</td>\n", | |
" <td>['Beef', 'Maturated Cheese', 'Hard Cheese']</td>\n", | |
" <td>Very full-bodied</td>\n", | |
" <td>PT</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1212696</td>\n", | |
" <td>103471</td>\n", | |
" <td>4.5</td>\n", | |
" <td>Presidential Colheita Port</td>\n", | |
" <td>Dessert/Port</td>\n", | |
" <td>['Touriga Nacional', 'Tinta Roriz']</td>\n", | |
" <td>['Beef', 'Maturated Cheese', 'Hard Cheese']</td>\n", | |
" <td>Very full-bodied</td>\n", | |
" <td>PT</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>1653207</td>\n", | |
" <td>103471</td>\n", | |
" <td>4.0</td>\n", | |
" <td>Presidential Colheita Port</td>\n", | |
" <td>Dessert/Port</td>\n", | |
" <td>['Touriga Nacional', 'Tinta Roriz']</td>\n", | |
" <td>['Beef', 'Maturated Cheese', 'Hard Cheese']</td>\n", | |
" <td>Very full-bodied</td>\n", | |
" <td>PT</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>1314309</td>\n", | |
" <td>103471</td>\n", | |
" <td>4.0</td>\n", | |
" <td>Presidential Colheita Port</td>\n", | |
" <td>Dessert/Port</td>\n", | |
" <td>['Touriga Nacional', 'Tinta Roriz']</td>\n", | |
" <td>['Beef', 'Maturated Cheese', 'Hard Cheese']</td>\n", | |
" <td>Very full-bodied</td>\n", | |
" <td>PT</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>1201703</td>\n", | |
" <td>103471</td>\n", | |
" <td>4.0</td>\n", | |
" <td>Presidential Colheita Port</td>\n", | |
" <td>Dessert/Port</td>\n", | |
" <td>['Touriga Nacional', 'Tinta Roriz']</td>\n", | |
" <td>['Beef', 'Maturated Cheese', 'Hard Cheese']</td>\n", | |
" <td>Very full-bodied</td>\n", | |
" <td>PT</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-e1519468-01b4-4ee5-8ab9-fd43d42fef20')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
"\n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
"\n", | |
"\n", | |
"\n", | |
" <div id=\"df-9cd9b989-6357-47bc-a5f1-127336615be6\">\n", | |
" <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-9cd9b989-6357-47bc-a5f1-127336615be6')\"\n", | |
" title=\"Suggest charts.\"\n", | |
" style=\"display:none;\">\n", | |
"\n", | |
"<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <g>\n", | |
" <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n", | |
" </g>\n", | |
"</svg>\n", | |
" </button>\n", | |
" </div>\n", | |
"\n", | |
"<style>\n", | |
" .colab-df-quickchart {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-quickchart:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-quickchart {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-quickchart:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
"</style>\n", | |
"\n", | |
" <script>\n", | |
" async function quickchart(key) {\n", | |
" const containerElement = document.querySelector('#' + key);\n", | |
" const charts = await google.colab.kernel.invokeFunction(\n", | |
" 'suggestCharts', [key], {});\n", | |
" }\n", | |
" </script>\n", | |
"\n", | |
" <script>\n", | |
"\n", | |
"function displayQuickchartButton(domScope) {\n", | |
" let quickchartButtonEl =\n", | |
" domScope.querySelector('#df-9cd9b989-6357-47bc-a5f1-127336615be6 button.colab-df-quickchart');\n", | |
" quickchartButtonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"}\n", | |
"\n", | |
" displayQuickchartButton(document);\n", | |
" </script>\n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-e1519468-01b4-4ee5-8ab9-fd43d42fef20 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-e1519468-01b4-4ee5-8ab9-fd43d42fef20');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 85 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dls = CollabDataLoaders.from_df(ratings, item_name='WineName', bs=64)\n", | |
"dls.show_batch()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 363 | |
}, | |
"id": "lnhX4Sa_Z0Kl", | |
"outputId": "20be8b31-2298-4643-92ed-3da944f19af4" | |
}, | |
"execution_count": 86, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>UserID</th>\n", | |
" <th>WineName</th>\n", | |
" <th>Rating</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1103573</td>\n", | |
" <td>Crianza</td>\n", | |
" <td>3.5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1289776</td>\n", | |
" <td>Saint-Julien (Grand Cru Classé)</td>\n", | |
" <td>4.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>1208908</td>\n", | |
" <td>Pedro Ximénez Murillo Selección del Centenario</td>\n", | |
" <td>4.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>1156121</td>\n", | |
" <td>Koonunga Hill Shiraz-Cabernet</td>\n", | |
" <td>3.5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>1176768</td>\n", | |
" <td>Grand Vintage Brut Champagne</td>\n", | |
" <td>4.5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>1360720</td>\n", | |
" <td>Blu Prosecco Extra Dry</td>\n", | |
" <td>4.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>1394003</td>\n", | |
" <td>Red Blend</td>\n", | |
" <td>3.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>1153307</td>\n", | |
" <td>Esporão Reserva Tinto</td>\n", | |
" <td>3.5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>1378235</td>\n", | |
" <td>Riesling</td>\n", | |
" <td>4.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>1128562</td>\n", | |
" <td>Blanc de Blancs Brut Champagne</td>\n", | |
" <td>5.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"n_users = len(dls.classes['UserID'])\n", | |
"n_wines = len(dls.classes['WineName'])\n", | |
"n_factors = 5\n", | |
"\n", | |
"user_factors = torch.randn(n_users, n_factors)\n", | |
"wine_factors = torch.randn(n_wines, n_factors)" | |
], | |
"metadata": { | |
"id": "kjlNMfvxerkM" | |
}, | |
"execution_count": 87, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class DotProduct(Module):\n", | |
" def __init__(self, n_users, n_wines, n_factors, y_range=(0,5.5)):\n", | |
" self.user_factors = Embedding(n_users, n_factors)\n", | |
" self.user_bias = Embedding(n_users, 1)\n", | |
" self.wine_factors = Embedding(n_wines, n_factors)\n", | |
" self.wine_bias = Embedding(n_wines, 1)\n", | |
" self.y_range = y_range\n", | |
"\n", | |
" def forward(self, x):\n", | |
" users = self.user_factors(x[:,0])\n", | |
" wines = self.wine_factors(x[:,1])\n", | |
" res = (users * wines).sum(dim=1, keepdim=True)\n", | |
" res += self.user_bias(x[:,0]) + self.wine_bias(x[:,1])\n", | |
" return sigmoid_range(res, *self.y_range)" | |
], | |
"metadata": { | |
"id": "Z48J8LR_h5J6" | |
}, | |
"execution_count": 103, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"model = DotProduct(n_users, n_wines, 50)\n", | |
"learn = Learner(dls, model, loss_func=MSELossFlat())\n", | |
"learn.fit_one_cycle(5, 5e-3)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 206 | |
}, | |
"id": "YGEdQKv7fFTh", | |
"outputId": "3ee372dc-3a94-4f5c-f5d4-e7ba8e117d8d" | |
}, | |
"execution_count": 106, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" progress:not([value]), progress:not([value])::-webkit-progress-bar {\n", | |
" background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>0.324220</td>\n", | |
" <td>0.305056</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>0.255239</td>\n", | |
" <td>0.274921</td>\n", | |
" <td>00:30</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>2</td>\n", | |
" <td>0.161736</td>\n", | |
" <td>0.289275</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>3</td>\n", | |
" <td>0.086282</td>\n", | |
" <td>0.291900</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>4</td>\n", | |
" <td>0.043770</td>\n", | |
" <td>0.294777</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"embs = get_emb_sz(dls)\n", | |
"embs" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "HL4M47cToTUg", | |
"outputId": "f4757205-16bd-49a6-d042-3b963dd2e47f" | |
}, | |
"execution_count": 108, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[(10507, 286), (804, 68)]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 108 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class CollabNN(Module):\n", | |
" def __init__(self, user_sz, item_sz, y_range=(0,5.5), n_act=100):\n", | |
" self.user_factors = Embedding(*user_sz)\n", | |
" self.item_factors = Embedding(*item_sz)\n", | |
" self.layers = nn.Sequential(\n", | |
" nn.Linear(user_sz[1]+item_sz[1], n_act),\n", | |
" nn.ReLU(),\n", | |
" nn.Linear(n_act, 1))\n", | |
" self.y_range = y_range\n", | |
"\n", | |
" def forward(self, x):\n", | |
" embs = self.user_factors(x[:,0]),self.item_factors(x[:,1])\n", | |
" x = self.layers(torch.cat(embs, dim=1))\n", | |
" return sigmoid_range(x, *self.y_range)" | |
], | |
"metadata": { | |
"id": "43TPrOSuoZ8z" | |
}, | |
"execution_count": 109, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"model = CollabNN(*embs)" | |
], | |
"metadata": { | |
"id": "auFEYOgHocjq" | |
}, | |
"execution_count": 110, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"learn = collab_learner(dls, use_nn=True, y_range=(0, 5.5), layers=[100,50])\n", | |
"learn.fit_one_cycle(5, 5e-3, wd=0.1)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 206 | |
}, | |
"id": "xTKnuKdMofKE", | |
"outputId": "f813b759-e079-4063-89c2-020f8c26ef26" | |
}, | |
"execution_count": 112, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" progress:not([value]), progress:not([value])::-webkit-progress-bar {\n", | |
" background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>0.293264</td>\n", | |
" <td>0.282452</td>\n", | |
" <td>02:05</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>0.262291</td>\n", | |
" <td>0.260737</td>\n", | |
" <td>01:33</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>2</td>\n", | |
" <td>0.249280</td>\n", | |
" <td>0.254707</td>\n", | |
" <td>01:32</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>3</td>\n", | |
" <td>0.228655</td>\n", | |
" <td>0.253738</td>\n", | |
" <td>01:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>4</td>\n", | |
" <td>0.188885</td>\n", | |
" <td>0.266842</td>\n", | |
" <td>01:31</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment