@Chris-hughes10
Created December 9, 2021 09:19
Comparing matrix factorixation with transformers using pytorch-accelerated blog post.ipynb
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "# Comparing matrix factorization with transformers for MovieLens recommendations using PyTorch-accelerated."
},
{
"metadata": {},
"cell_type": "markdown",
"source": "By Chris Hughes"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "The package versions used are:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# torch==1.10.0\n# torchmetrics==0.6.0\n# pytorch-accelerated==0.1.7",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "from pathlib import Path\n\nimport numpy as np\nimport pandas as pd\nfrom statsmodels.distributions.empirical_distribution import ECDF\nimport matplotlib.pyplot as plt",
"execution_count": 1,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Prepare Data"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "For our dataset, we shall use MovieLens-1M, a collection of roughly one million ratings from about 6,000 users on about 4,000 movies. This dataset was collected and is maintained by GroupLens, a research group at the University of Minnesota, and was released in 2003. It has been used frequently in the machine learning community and is commonly presented as a benchmark in academic papers."
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Download MovieLens-1M dataset"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "!wget http://files.grouplens.org/datasets/movielens/ml-1m.zip",
"execution_count": 2,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "--2021-12-05 11:22:27-- http://files.grouplens.org/datasets/movielens/ml-1m.zip\nResolving files.grouplens.org... 128.101.65.152\nConnecting to files.grouplens.org|128.101.65.152|:80... connected.\nHTTP request sent, awaiting response... 200 OK\nLength: 5917549 (5.6M) [application/zip]\nSaving to: ‘ml-1m.zip.1’\n\nml-1m.zip.1 100%[===================>] 5.64M 6.56MB/s in 0.9s \n\n2021-12-05 11:22:28 (6.56 MB/s) - ‘ml-1m.zip.1’ saved [5917549/5917549]\n\n"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "!unzip -o ml-1m.zip",
"execution_count": 3,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Archive: ml-1m.zip\nreplace ml-1m/movies.dat? [y]es, [n]o, [A]ll, [N]one, [r]ename: ^C\n"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Load Data"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "MovieLens-1M consists of three files, 'movies.dat', 'users.dat', and 'ratings.dat'. In each file, the fields on every line are separated by '::'. Let's load each file into a DataFrame and inspect its format:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "dataset_path = Path('ml-1m')",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "users = pd.read_csv(\n dataset_path/\"users.dat\",\n sep=\"::\",\n names=[\"user_id\", \"sex\", \"age_group\", \"occupation\", \"zip_code\"],\n encoding='latin-1',\n engine='python'\n)\n\nratings = pd.read_csv(\n dataset_path/\"ratings.dat\",\n sep=\"::\",\n names=[\"user_id\", \"movie_id\", \"rating\", \"unix_timestamp\"],\n encoding='latin-1',\n engine='python'\n)\n\nmovies = pd.read_csv(\n dataset_path/\"movies.dat\", sep=\"::\", names=[\"movie_id\", \"title\", \"genres\"],\n encoding='latin-1',\n engine='python'\n)\n",
"execution_count": 3,
"outputs": []
},
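{
"metadata": {},
"cell_type": "markdown",
"source": "The `engine='python'` argument is needed because the separator `'::'` is longer than one character. pandas treats such separators as regular expressions, which only its Python parsing engine supports. As a minimal sketch, we can parse two lines from an in-memory buffer and check that the fields are split as expected. The lines use the same `user_id::movie_id::rating::unix_timestamp` layout as 'ratings.dat', with values copied from the first rows of the `ratings` DataFrame displayed below:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import io\n\nimport pandas as pd\n\n# Two lines in the user_id::movie_id::rating::unix_timestamp layout of ratings.dat\nsample = '1::1193::5::978300760\\n1::661::3::978302109\\n'\n\n# A multi-character separator is interpreted as a regular expression,\n# so pandas needs the Python parsing engine here\nsample_ratings = pd.read_csv(\n    io.StringIO(sample),\n    sep='::',\n    names=['user_id', 'movie_id', 'rating', 'unix_timestamp'],\n    engine='python',\n)\nprint(sample_ratings)",
"execution_count": null,
"outputs": []
},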
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "users",
"execution_count": 4,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>user_id</th>\n <th>sex</th>\n <th>age_group</th>\n <th>occupation</th>\n <th>zip_code</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>1</td>\n <td>F</td>\n <td>1</td>\n <td>10</td>\n <td>48067</td>\n </tr>\n <tr>\n <th>1</th>\n <td>2</td>\n <td>M</td>\n <td>56</td>\n <td>16</td>\n <td>70072</td>\n </tr>\n <tr>\n <th>2</th>\n <td>3</td>\n <td>M</td>\n <td>25</td>\n <td>15</td>\n <td>55117</td>\n </tr>\n <tr>\n <th>3</th>\n <td>4</td>\n <td>M</td>\n <td>45</td>\n <td>7</td>\n <td>02460</td>\n </tr>\n <tr>\n <th>4</th>\n <td>5</td>\n <td>M</td>\n <td>25</td>\n <td>20</td>\n <td>55455</td>\n </tr>\n <tr>\n <th>...</th>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n </tr>\n <tr>\n <th>6035</th>\n <td>6036</td>\n <td>F</td>\n <td>25</td>\n <td>15</td>\n <td>32603</td>\n </tr>\n <tr>\n <th>6036</th>\n <td>6037</td>\n <td>F</td>\n <td>45</td>\n <td>1</td>\n <td>76006</td>\n </tr>\n <tr>\n <th>6037</th>\n <td>6038</td>\n <td>F</td>\n <td>56</td>\n <td>1</td>\n <td>14706</td>\n </tr>\n <tr>\n <th>6038</th>\n <td>6039</td>\n <td>F</td>\n <td>45</td>\n <td>0</td>\n <td>01060</td>\n </tr>\n <tr>\n <th>6039</th>\n <td>6040</td>\n <td>M</td>\n <td>25</td>\n <td>6</td>\n <td>11106</td>\n </tr>\n </tbody>\n</table>\n<p>6040 rows × 5 columns</p>\n</div>",
"text/plain": " user_id sex age_group occupation zip_code\n0 1 F 1 10 48067\n1 2 M 56 16 70072\n2 3 M 25 15 55117\n3 4 M 45 7 02460\n4 5 M 25 20 55455\n... ... .. ... ... ...\n6035 6036 F 25 15 32603\n6036 6037 F 45 1 76006\n6037 6038 F 56 1 14706\n6038 6039 F 45 0 01060\n6039 6040 M 25 6 11106\n\n[6040 rows x 5 columns]"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "movies",
"execution_count": 5,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>movie_id</th>\n <th>title</th>\n <th>genres</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>1</td>\n <td>Toy Story (1995)</td>\n <td>Animation|Children's|Comedy</td>\n </tr>\n <tr>\n <th>1</th>\n <td>2</td>\n <td>Jumanji (1995)</td>\n <td>Adventure|Children's|Fantasy</td>\n </tr>\n <tr>\n <th>2</th>\n <td>3</td>\n <td>Grumpier Old Men (1995)</td>\n <td>Comedy|Romance</td>\n </tr>\n <tr>\n <th>3</th>\n <td>4</td>\n <td>Waiting to Exhale (1995)</td>\n <td>Comedy|Drama</td>\n </tr>\n <tr>\n <th>4</th>\n <td>5</td>\n <td>Father of the Bride Part II (1995)</td>\n <td>Comedy</td>\n </tr>\n <tr>\n <th>...</th>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n </tr>\n <tr>\n <th>3878</th>\n <td>3948</td>\n <td>Meet the Parents (2000)</td>\n <td>Comedy</td>\n </tr>\n <tr>\n <th>3879</th>\n <td>3949</td>\n <td>Requiem for a Dream (2000)</td>\n <td>Drama</td>\n </tr>\n <tr>\n <th>3880</th>\n <td>3950</td>\n <td>Tigerland (2000)</td>\n <td>Drama</td>\n </tr>\n <tr>\n <th>3881</th>\n <td>3951</td>\n <td>Two Family House (2000)</td>\n <td>Drama</td>\n </tr>\n <tr>\n <th>3882</th>\n <td>3952</td>\n <td>Contender, The (2000)</td>\n <td>Drama|Thriller</td>\n </tr>\n </tbody>\n</table>\n<p>3883 rows × 3 columns</p>\n</div>",
"text/plain": " movie_id title \\\n0 1 Toy Story (1995) \n1 2 Jumanji (1995) \n2 3 Grumpier Old Men (1995) \n3 4 Waiting to Exhale (1995) \n4 5 Father of the Bride Part II (1995) \n... ... ... \n3878 3948 Meet the Parents (2000) \n3879 3949 Requiem for a Dream (2000) \n3880 3950 Tigerland (2000) \n3881 3951 Two Family House (2000) \n3882 3952 Contender, The (2000) \n\n genres \n0 Animation|Children's|Comedy \n1 Adventure|Children's|Fantasy \n2 Comedy|Romance \n3 Comedy|Drama \n4 Comedy \n... ... \n3878 Comedy \n3879 Drama \n3880 Drama \n3881 Drama \n3882 Drama|Thriller \n\n[3883 rows x 3 columns]"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "ratings",
"execution_count": 6,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>user_id</th>\n <th>movie_id</th>\n <th>rating</th>\n <th>unix_timestamp</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>1</td>\n <td>1193</td>\n <td>5</td>\n <td>978300760</td>\n </tr>\n <tr>\n <th>1</th>\n <td>1</td>\n <td>661</td>\n <td>3</td>\n <td>978302109</td>\n </tr>\n <tr>\n <th>2</th>\n <td>1</td>\n <td>914</td>\n <td>3</td>\n <td>978301968</td>\n </tr>\n <tr>\n <th>3</th>\n <td>1</td>\n <td>3408</td>\n <td>4</td>\n <td>978300275</td>\n </tr>\n <tr>\n <th>4</th>\n <td>1</td>\n <td>2355</td>\n <td>5</td>\n <td>978824291</td>\n </tr>\n <tr>\n <th>...</th>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n </tr>\n <tr>\n <th>1000204</th>\n <td>6040</td>\n <td>1091</td>\n <td>1</td>\n <td>956716541</td>\n </tr>\n <tr>\n <th>1000205</th>\n <td>6040</td>\n <td>1094</td>\n <td>5</td>\n <td>956704887</td>\n </tr>\n <tr>\n <th>1000206</th>\n <td>6040</td>\n <td>562</td>\n <td>5</td>\n <td>956704746</td>\n </tr>\n <tr>\n <th>1000207</th>\n <td>6040</td>\n <td>1096</td>\n <td>4</td>\n <td>956715648</td>\n </tr>\n <tr>\n <th>1000208</th>\n <td>6040</td>\n <td>1097</td>\n <td>4</td>\n <td>956715569</td>\n </tr>\n </tbody>\n</table>\n<p>1000209 rows × 4 columns</p>\n</div>",
"text/plain": " user_id movie_id rating unix_timestamp\n0 1 1193 5 978300760\n1 1 661 3 978302109\n2 1 914 3 978301968\n3 1 3408 4 978300275\n4 1 2355 5 978824291\n... ... ... ... ...\n1000204 6040 1091 1 956716541\n1000205 6040 1094 5 956704887\n1000206 6040 562 5 956704746\n1000207 6040 1096 4 956715648\n1000208 6040 1097 4 956715569\n\n[1000209 rows x 4 columns]"
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Let's combine the ratings and the movie titles into a single DataFrame, to make the data easier to work with."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "ratings_df = pd.merge(ratings, movies)[['user_id', 'title', 'rating', 'unix_timestamp']]",
"execution_count": 7,
"outputs": []
},
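{
"metadata": {},
"cell_type": "markdown",
"source": "Because no `on` argument is passed, `pd.merge` joins on every column that the two DataFrames have in common, which here is `movie_id`. It also performs an inner join by default, so movies without any ratings are dropped. This is why `ratings_df` contains fewer titles than there are rows in `movies`. A small sketch with the hypothetical toy frames `toy_ratings` and `toy_movies` illustrates this behavior:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import pandas as pd\n\n# Hypothetical toy data: movie 3 has not been rated by anyone\ntoy_ratings = pd.DataFrame(\n    {'user_id': [1, 1, 2], 'movie_id': [1, 2, 1], 'rating': [5, 3, 4]}\n)\ntoy_movies = pd.DataFrame(\n    {'movie_id': [1, 2, 3], 'title': ['A (1995)', 'B (1995)', 'C (1995)']}\n)\n\n# With no 'on' argument, pandas joins on the shared 'movie_id' column (inner join)\ntoy_merged = pd.merge(toy_ratings, toy_movies)\nprint(toy_merged[['user_id', 'title', 'rating']])",
"execution_count": null,
"outputs": []
},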
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "ratings_df[\"user_id\"] = ratings_df[\"user_id\"].astype(str)",
"execution_count": 8,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Using pandas, we can print some high-level statistics about the dataset, which may be useful to us."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "ratings_per_user = ratings_df.groupby('user_id').rating.count()\nratings_per_item = ratings_df.groupby('title').rating.count()\n\nprint(f\"Total No. of users: {len(ratings_df.user_id.unique())}\")\nprint(f\"Total No. of items: {len(ratings_df.title.unique())}\")\nprint(\"\\n\")\n\nprint(f\"Max observed rating: {ratings_df.rating.max()}\")\nprint(f\"Min observed rating: {ratings_df.rating.min()}\")\nprint(\"\\n\")\n\nprint(f\"Max no. of user ratings: {ratings_per_user.max()}\")\nprint(f\"Min no. of user ratings: {ratings_per_user.min()}\")\nprint(f\"Median no. of ratings per user: {ratings_per_user.median()}\")\nprint(\"\\n\")\n\nprint(f\"Max no. of item ratings: {ratings_per_item.max()}\")\nprint(f\"Min no. of item ratings: {ratings_per_item.min()}\")\nprint(f\"Median no. of ratings per item: {ratings_per_item.median()}\")\n",
"execution_count": 9,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Total No. of users: 6040\nTotal No. of items: 3706\n\n\nMax observed rating: 5\nMin observed rating: 1\n\n\nMax no. of user ratings: 2314\nMin no. of user ratings: 20\nMedian no. of ratings per user: 96.0\n\n\nMax no. of item ratings: 3428\nMin no. of item ratings: 1\nMedian no. of ratings per item: 123.5\n"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "From this, we can see that all ratings are between 1 and 5. Every movie in our combined DataFrame has been rated at least once; this is guaranteed by the merge, which only keeps movies that appear in the ratings. As every user has rated at least 20 movies, we don't have to worry about how to recommend items to a user whose preferences we know nothing about, which is known as the cold-start problem. In the real world, however, this is often not the case!"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Splitting into training and validation sets"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Before we start modeling, we need to split this dataset into training and validation sets. A common approach is to randomly sample a selection of rows, which works well in many cases. However, we intend to train a transformer model on sequences of ratings, so this approach is not suitable here. If we were to remove a random set of rows, it is likely that, for some users, ratings from the middle of a sequence would end up in the validation set. That split would not represent the task that we are trying to model.\n\nTo avoid this, we can use a strategy known as 'leave-one-out' validation. Here, we select the chronologically last rating for each user, provided that they have rated at least a defined minimum number of items. As this is a good representation of the task we are trying to model, this is the approach we shall use here.\n\nLet's define a function to get the last n ratings for each user:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "def get_last_n_ratings_by_user(\n df, n, min_ratings_per_user=1, user_colname=\"user_id\", timestamp_colname=\"unix_timestamp\"\n):\n return (\n df.groupby(user_colname)\n .filter(lambda x: len(x) >= min_ratings_per_user)\n .sort_values(timestamp_colname)\n .groupby(user_colname)\n .tail(n)\n .sort_values(user_colname)\n )",
"execution_count": 10,
"outputs": []
},
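{
"metadata": {},
"cell_type": "markdown",
"source": "To check how this function behaves, here is a small sketch on a hypothetical toy DataFrame with made-up timestamps. First, users with fewer than `min_ratings_per_user` ratings are removed. Then each remaining user's ratings are sorted by timestamp and the last `n` are kept. The function is repeated so that the cell runs on its own."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import pandas as pd\n\n\ndef get_last_n_ratings_by_user(\n    df, n, min_ratings_per_user=1, user_colname='user_id', timestamp_colname='unix_timestamp'\n):\n    return (\n        df.groupby(user_colname)\n        # Drop users with too few ratings\n        .filter(lambda x: len(x) >= min_ratings_per_user)\n        # Order chronologically, then keep each user's last n rows\n        .sort_values(timestamp_colname)\n        .groupby(user_colname)\n        .tail(n)\n        .sort_values(user_colname)\n    )\n\n\n# Hypothetical toy ratings: user '3' has only one rating\ntoy = pd.DataFrame(\n    {\n        'user_id': ['1', '1', '1', '2', '2', '3'],\n        'title': ['A', 'B', 'C', 'A', 'C', 'B'],\n        'unix_timestamp': [30, 10, 20, 5, 50, 1],\n    }\n)\n\nlast_toy = get_last_n_ratings_by_user(toy, n=1, min_ratings_per_user=2)\nprint(last_toy)",
"execution_count": null,
"outputs": []
},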
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "get_last_n_ratings_by_user(ratings_df, 1)",
"execution_count": 11,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>user_id</th>\n <th>title</th>\n <th>rating</th>\n <th>unix_timestamp</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>28501</th>\n <td>1</td>\n <td>Pocahontas (1995)</td>\n <td>5</td>\n <td>978824351</td>\n </tr>\n <tr>\n <th>482398</th>\n <td>10</td>\n <td>Hero (1992)</td>\n <td>5</td>\n <td>980638688</td>\n </tr>\n <tr>\n <th>800008</th>\n <td>100</td>\n <td>Apocalypse Now (1979)</td>\n <td>2</td>\n <td>977594963</td>\n </tr>\n <tr>\n <th>496041</th>\n <td>1000</td>\n <td>Streetcar Named Desire, A (1951)</td>\n <td>5</td>\n <td>975042421</td>\n </tr>\n <tr>\n <th>305563</th>\n <td>1001</td>\n <td>Austin Powers: The Spy Who Shagged Me (1999)</td>\n <td>2</td>\n <td>1028605534</td>\n </tr>\n <tr>\n <th>...</th>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n </tr>\n <tr>\n <th>767773</th>\n <td>995</td>\n <td>French Kiss (1995)</td>\n <td>3</td>\n <td>975099776</td>\n </tr>\n <tr>\n <th>573889</th>\n <td>996</td>\n <td>Almost Famous (2000)</td>\n <td>5</td>\n <td>1001227064</td>\n </tr>\n <tr>\n <th>76463</th>\n <td>997</td>\n <td>Gladiator (2000)</td>\n <td>4</td>\n <td>978915132</td>\n </tr>\n <tr>\n <th>998801</th>\n <td>998</td>\n <td>See the Sea (Regarde la mer) (1997)</td>\n <td>5</td>\n <td>975192573</td>\n </tr>\n <tr>\n <th>784436</th>\n <td>999</td>\n <td>Free Willy (1993)</td>\n <td>2</td>\n <td>975364891</td>\n </tr>\n </tbody>\n</table>\n<p>6040 rows × 4 columns</p>\n</div>",
"text/plain": " user_id title rating \\\n28501 1 Pocahontas (1995) 5 \n482398 10 Hero (1992) 5 \n800008 100 Apocalypse Now (1979) 2 \n496041 1000 Streetcar Named Desire, A (1951) 5 \n305563 1001 Austin Powers: The Spy Who Shagged Me (1999) 2 \n... ... ... ... \n767773 995 French Kiss (1995) 3 \n573889 996 Almost Famous (2000) 5 \n76463 997 Gladiator (2000) 4 \n998801 998 See the Sea (Regarde la mer) (1997) 5 \n784436 999 Free Willy (1993) 2 \n\n unix_timestamp \n28501 978824351 \n482398 980638688 \n800008 977594963 \n496041 975042421 \n305563 1028605534 \n... ... \n767773 975099776 \n573889 1001227064 \n76463 978915132 \n998801 975192573 \n784436 975364891 \n\n[6040 rows x 4 columns]"
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "We can now use this to define another function that marks the last n ratings per user as our validation set, recording this in an 'is_valid' column:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "def mark_last_n_ratings_as_validation_set(\n df, n, min_ratings=1, user_colname=\"user_id\", timestamp_colname=\"unix_timestamp\"\n):\n \"\"\"\n Mark the chronologically last n ratings of each user as the validation set.\n This is done by adding an 'is_valid' column to the df, which is modified in place.\n :param df: a DataFrame containing user item ratings\n :param n: the number of ratings per user to include in the validation set\n :param min_ratings: only include users with at least this many ratings\n :param user_colname: the name of the column containing user ids\n :param timestamp_colname: the name of the column containing the timestamps\n :return: the same df with the additional 'is_valid' column added\n \"\"\"\n df[\"is_valid\"] = False\n df.loc[\n get_last_n_ratings_by_user(\n df,\n n,\n min_ratings,\n user_colname=user_colname,\n timestamp_colname=timestamp_colname,\n ).index,\n \"is_valid\",\n ] = True\n\n return df",
"execution_count": 12,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Applying this to our DataFrame, we can see that we now have a validation set of 6040 rows - one for each user."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "mark_last_n_ratings_as_validation_set(ratings_df, 1)",
"execution_count": 13,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>user_id</th>\n <th>title</th>\n <th>rating</th>\n <th>unix_timestamp</th>\n <th>is_valid</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>1</td>\n <td>One Flew Over the Cuckoo's Nest (1975)</td>\n <td>5</td>\n <td>978300760</td>\n <td>False</td>\n </tr>\n <tr>\n <th>1</th>\n <td>2</td>\n <td>One Flew Over the Cuckoo's Nest (1975)</td>\n <td>5</td>\n <td>978298413</td>\n <td>False</td>\n </tr>\n <tr>\n <th>2</th>\n <td>12</td>\n <td>One Flew Over the Cuckoo's Nest (1975)</td>\n <td>4</td>\n <td>978220179</td>\n <td>False</td>\n </tr>\n <tr>\n <th>3</th>\n <td>15</td>\n <td>One Flew Over the Cuckoo's Nest (1975)</td>\n <td>4</td>\n <td>978199279</td>\n <td>False</td>\n </tr>\n <tr>\n <th>4</th>\n <td>17</td>\n <td>One Flew Over the Cuckoo's Nest (1975)</td>\n <td>5</td>\n <td>978158471</td>\n <td>False</td>\n </tr>\n <tr>\n <th>...</th>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n </tr>\n <tr>\n <th>1000204</th>\n <td>5949</td>\n <td>Modulations (1998)</td>\n <td>5</td>\n <td>958846401</td>\n <td>False</td>\n </tr>\n <tr>\n <th>1000205</th>\n <td>5675</td>\n <td>Broken Vessels (1998)</td>\n <td>3</td>\n <td>976029116</td>\n <td>False</td>\n </tr>\n <tr>\n <th>1000206</th>\n <td>5780</td>\n <td>White Boys (1999)</td>\n <td>1</td>\n <td>958153068</td>\n <td>False</td>\n </tr>\n <tr>\n <th>1000207</th>\n <td>5851</td>\n <td>One Little Indian (1973)</td>\n <td>5</td>\n <td>957756608</td>\n <td>False</td>\n </tr>\n <tr>\n <th>1000208</th>\n <td>5938</td>\n <td>Five Wives, Three Secretaries and Me (1998)</td>\n <td>4</td>\n <td>957273353</td>\n <td>False</td>\n </tr>\n </tbody>\n</table>\n<p>1000209 rows × 5 columns</p>\n</div>",
"text/plain": " user_id title rating \\\n0 1 One Flew Over the Cuckoo's Nest (1975) 5 \n1 2 One Flew Over the Cuckoo's Nest (1975) 5 \n2 12 One Flew Over the Cuckoo's Nest (1975) 4 \n3 15 One Flew Over the Cuckoo's Nest (1975) 4 \n4 17 One Flew Over the Cuckoo's Nest (1975) 5 \n... ... ... ... \n1000204 5949 Modulations (1998) 5 \n1000205 5675 Broken Vessels (1998) 3 \n1000206 5780 White Boys (1999) 1 \n1000207 5851 One Little Indian (1973) 5 \n1000208 5938 Five Wives, Three Secretaries and Me (1998) 4 \n\n unix_timestamp is_valid \n0 978300760 False \n1 978298413 False \n2 978220179 False \n3 978199279 False \n4 978158471 False \n... ... ... \n1000204 958846401 False \n1000205 976029116 False \n1000206 958153068 False \n1000207 957756608 False \n1000208 957273353 False \n\n[1000209 rows x 5 columns]"
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "train_df = ratings_df[~ratings_df.is_valid]\nvalid_df = ratings_df[ratings_df.is_valid]",
"execution_count": 14,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "len(valid_df)",
"execution_count": 15,
"outputs": [
{
"data": {
"text/plain": "6040"
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Even when considering model benchmarks on the same dataset, to have a fair comparison, it is important to understand how the data has been split and to make sure that the approaches taken are consistent!"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Creating a Baseline Model"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "When starting a new modeling task, it is often a good idea to create a very simple model, known as a baseline model. A baseline performs the task in a straightforward way and requires minimal effort to implement. We can then use its metrics as a comparison for all future approaches. If a complex model is getting worse results than the baseline model, this is a bad sign!\n\nHere, a simple approach is to predict the same value for every user-movie pair, irrespective of context: the average rating across the whole training set. As the mean can be heavily affected by outliers, let's use the median for this. We can easily calculate the median rating from our training set as follows:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "median_rating = train_df.rating.median(); median_rating",
"execution_count": 17,
"outputs": [
{
"data": {
"text/plain": "4.0"
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "We can then use this as the prediction for every rating in the validation set and calculate our metrics:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import math\nfrom sklearn.metrics import mean_squared_error, mean_absolute_error\n\npredictions = np.array([median_rating]* len(valid_df))\n\nmae = mean_absolute_error(valid_df.rating, predictions)\nmse = mean_squared_error(valid_df.rating, predictions)\nrmse = math.sqrt(mse)\n\nprint(f'mae: {mae}')\nprint(f'mse: {mse}')\nprint(f'rmse: {rmse}')",
"execution_count": 18,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "mae: 0.91158940397351\nmse: 1.5304635761589405\nrmse: 1.2371190630488806\n"
}
]
},
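{
"metadata": {},
"cell_type": "markdown",
"source": "For reference, MAE is the mean of the absolute errors, and MSE is the mean of the squared errors. RMSE is the square root of the MSE, which puts it back on the same scale as the ratings. A quick sketch with four made-up ratings and a constant prediction of 4 shows the arithmetic, computed directly with NumPy rather than scikit-learn:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import math\n\nimport numpy as np\n\n# Hypothetical ratings, each predicted with the same constant value of 4\ntoy_actual = np.array([5, 3, 4, 1])\ntoy_predictions = np.full(len(toy_actual), 4.0)\n\ntoy_errors = toy_actual - toy_predictions   # [1, -1, 0, -3]\ntoy_mae = np.abs(toy_errors).mean()         # (1 + 1 + 0 + 3) / 4 = 1.25\ntoy_mse = (toy_errors ** 2).mean()          # (1 + 1 + 0 + 9) / 4 = 2.75\ntoy_rmse = math.sqrt(toy_mse)               # about 1.658\n\nprint(toy_mae, toy_mse, toy_rmse)",
"execution_count": null,
"outputs": []
},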
{
"metadata": {},
"cell_type": "markdown",
"source": "## Matrix factorization with bias"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "One very popular approach to recommendations, both in academia and industry, is matrix factorization.\n\nInstead of representing ratings in a table, such as our DataFrame, we can represent a set of user-item ratings as a matrix, with one row per user and one column per item. We can visualize this on a sample of our data as presented below:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "ratings_df[((ratings_df.user_id == '1') | \n (ratings_df.user_id == '2')| \n (ratings_df.user_id == '4')) \n & ((ratings_df.title == \"One Flew Over the Cuckoo's Nest (1975)\") | \n (ratings_df.title == \"To Kill a Mockingbird (1962)\")| \n (ratings_df.title == \"Saving Private Ryan (1998)\"))].pivot_table('rating', index='user_id', columns='title').fillna('?')",
"execution_count": 19,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th>title</th>\n <th>One Flew Over the Cuckoo's Nest (1975)</th>\n <th>Saving Private Ryan (1998)</th>\n <th>To Kill a Mockingbird (1962)</th>\n </tr>\n <tr>\n <th>user_id</th>\n <th></th>\n <th></th>\n <th></th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>1</th>\n <td>5.0</td>\n <td>5.0</td>\n <td>4.0</td>\n </tr>\n <tr>\n <th>2</th>\n <td>5.0</td>\n <td>4.0</td>\n <td>4.0</td>\n </tr>\n <tr>\n <th>4</th>\n <td>?</td>\n <td>5.0</td>\n <td>?</td>\n </tr>\n </tbody>\n</table>\n</div>",
"text/plain": "title One Flew Over the Cuckoo's Nest (1975) Saving Private Ryan (1998) \\\nuser_id \n1 5.0 5.0 \n2 5.0 4.0 \n4 ? 5.0 \n\ntitle To Kill a Mockingbird (1962) \nuser_id \n1 4.0 \n2 4.0 \n4 ? "
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "As not every user will have rated every movie, we can see that some values are missing. Therefore, we can formulate our recommendation problem in the following way:\n\nHow can we fill in the blanks, such that the values are consistent with the existing ratings in the matrix?\n\nOne way to approach this is to assume that our ratings matrix can be approximated by the product of two smaller matrices. One contains a vector of latent factors for each user, and the other contains a vector of latent factors for each item."
},
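{
"metadata": {},
"cell_type": "markdown",
"source": "As a minimal sketch of this idea with made-up numbers, suppose each of 3 users and each of 3 movies is described by a vector of k = 2 latent factors. Stacking these vectors gives a user matrix $P$ of shape 3 × 2 and an item matrix $Q$ of shape 3 × 2. The product $P Q^T$ is then a complete 3 × 3 matrix of predicted ratings, with no blanks. Each predicted rating is simply the dot product of one user vector with one movie vector. The model defined below follows the same idea, and adds a learned bias for each user and each item."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import numpy as np\n\n# Hypothetical latent factors: one row per user / movie, k = 2 columns\nuser_factors = np.array([[1.0, 2.0],\n                         [1.5, 1.0],\n                         [2.0, 1.5]])\nitem_factors = np.array([[2.0, 1.0],\n                         [1.0, 2.0],\n                         [1.5, 1.5]])\n\n# (3 x 2) @ (2 x 3) -> a dense 3 x 3 matrix of predicted ratings\nmf_predictions = user_factors @ item_factors.T\nprint(mf_predictions)\n\n# Entry [u, i] is the dot product of user u's row and movie i's row\nprint(mf_predictions[2, 0], user_factors[2] @ item_factors[0])",
"execution_count": null,
"outputs": []
},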
{
"metadata": {},
"cell_type": "markdown",
"source": "Before we think about training a model, we first need to get the data into the correct format. Currently, each movie is represented by its title, which is a string; we need to convert this to an integer format so that we can feed it into the model. While we already have an ID representing each user, let's also create our own encoding for this. I generally find it good practice to control all the encodings related to a training process, rather than relying on predefined ID systems defined elsewhere. You would be surprised how many IDs that are supposed to be immutable and unique turn out to be otherwise in the real world!\n\nHere, we can do this very simply by enumerating every unique value for both users and movies. We can create lookup tables for this as shown below:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "user_lookup = {v: i+1 for i, v in enumerate(ratings_df['user_id'].unique())}",
"execution_count": 20,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "movie_lookup = {v: i+1 for i, v in enumerate(ratings_df['title'].unique())}",
"execution_count": 21,
"outputs": []
},
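{
"metadata": {},
"cell_type": "markdown",
"source": "Starting the numbering at 1 leaves index 0 free, which lines up with the padding index used by the embedding layers of the model defined below. A small sketch with hypothetical titles shows the resulting mapping. The `encode` helper, which maps values missing from the lookup to 0, is an illustrative assumption and is not part of this notebook's pipeline:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# Hypothetical titles, in the order in which they first appear\ntitles = ['Toy Story (1995)', 'Jumanji (1995)', 'Toy Story (1995)', 'Heat (1995)']\n\n# Same pattern as movie_lookup: unique values in order of first appearance\n# (dict.fromkeys stands in for pandas' unique()), numbered from 1\ntoy_lookup = {v: i + 1 for i, v in enumerate(dict.fromkeys(titles))}\n\n\ndef encode(value, lookup):\n    # Hypothetical helper: values missing from the lookup fall back to index 0\n    return lookup.get(value, 0)\n\n\nprint(toy_lookup)\nprint(encode('Heat (1995)', toy_lookup), encode('Unknown Movie (2001)', toy_lookup))",
"execution_count": null,
"outputs": []
},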
{
"metadata": {},
"cell_type": "markdown",
"source": "Now that we can encode our features, we need to define a PyTorch Dataset. It wraps our DataFrame and returns each encoded user-item pair along with its rating."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import torch\nfrom torch.utils.data import Dataset\n\n\nclass UserItemRatingDataset(Dataset):\n    def __init__(self, df, movie_lookup, user_lookup):\n        self.df = df\n        self.movie_lookup = movie_lookup\n        self.user_lookup = user_lookup\n\n    def __getitem__(self, index):\n        row = self.df.iloc[index]\n        # Encode the raw user id and movie title as integer indices\n        user_id = self.user_lookup[row.user_id]\n        movie_id = self.movie_lookup[row.title]\n\n        rating = torch.tensor(row.rating, dtype=torch.float32)\n\n        # Inputs are returned as a tuple, with the rating as the target\n        return (user_id, movie_id), rating\n\n    def __len__(self):\n        return len(self.df)\n",
"execution_count": 22,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "We can now use this to create our training and validation datasets:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "train_dataset = UserItemRatingDataset(train_df, movie_lookup, user_lookup)\nvalid_dataset = UserItemRatingDataset(valid_df, movie_lookup, user_lookup)",
"execution_count": 23,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Next, let's define the model."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import torch\nfrom torch import nn\n\n\nclass MfDotBias(nn.Module):\n    \"\"\"Matrix factorization: dot product of user and item embeddings, plus optional biases.\"\"\"\n\n    def __init__(\n        self, n_factors, n_users, n_items, ratings_range=None, use_biases=True\n    ):\n        super().__init__()\n        self.bias = use_biases\n        self.y_range = ratings_range\n        # One extra row in each table for the padding index 0\n        self.user_embedding = nn.Embedding(n_users+1, n_factors, padding_idx=0)\n        self.item_embedding = nn.Embedding(n_items+1, n_factors, padding_idx=0)\n\n        if use_biases:\n            self.user_bias = nn.Embedding(n_users+1, 1, padding_idx=0)\n            self.item_bias = nn.Embedding(n_items+1, 1, padding_idx=0)\n\n    def forward(self, inputs):\n        users, items = inputs\n        # Element-wise product summed over the factor dimension = row-wise dot product\n        dot = self.user_embedding(users) * self.item_embedding(items)\n        result = dot.sum(1)\n        if self.bias:\n            result = (\n                result + self.user_bias(users).squeeze(-1) + self.item_bias(items).squeeze(-1)\n            )\n\n        if self.y_range is None:\n            return result\n        else:\n            # Squash to (0, 1) with a sigmoid, then rescale to (y_range[0], y_range[1])\n            return (\n                torch.sigmoid(result) * (self.y_range[1] - self.y_range[0])\n                + self.y_range[0]\n            )",
"execution_count": 24,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "As we can see, this is very simple to define. Note that an embedding layer is simply a lookup table. It is therefore important that each embedding layer is large enough to contain every index that will be seen during training and evaluation. Because of this, we use the number of unique users and items observed in the full dataset, not just in the training set. We have also specified a padding embedding at index 0, which can be used for any unknown values. PyTorch initializes this entry as a zero vector, and it is not updated during training.\n\nAdditionally, as this is a regression task, the range of values that the model could predict is potentially unbounded. While the model can learn to restrict its outputs to between 1 and 5, we can make this easier for the model by modifying the architecture to restrict this range prior to training. We have done this by applying the sigmoid function to the model's output, which restricts the range to between 0 and 1, and then scaling this to a range that we can define. As the sigmoid only approaches 0 and 1 asymptotically, a range slightly wider than the observed ratings, such as the 0.5 to 5.5 used in the training function, lets the model output ratings of 1 and 5 without needing extreme raw scores."
},
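{
"metadata": {},
"cell_type": "markdown",
"source": "The output scaling can be checked in isolation. This NumPy sketch mirrors the final step of `MfDotBias.forward` for a few made-up raw scores, using the range `[0.5, 5.5]` that is passed to the model in the training function. A raw score of 0 maps to the midpoint of the range, and large scores of either sign approach the bounds without ever reaching them. The `scale_to_range` function is a hypothetical stand-in, not part of the model."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import numpy as np\n\n\ndef scale_to_range(raw_scores, y_range):\n    # The sigmoid squashes any real number into the open interval (0, 1)\n    squashed = 1.0 / (1.0 + np.exp(-raw_scores))\n    # Stretch and shift into (y_range[0], y_range[1])\n    return squashed * (y_range[1] - y_range[0]) + y_range[0]\n\n\nraw = np.array([-10.0, -1.0, 0.0, 1.0, 10.0])\nscaled = scale_to_range(raw, [0.5, 5.5])\nprint(scaled)",
"execution_count": null,
"outputs": []
},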
{
"metadata": {},
"cell_type": "markdown",
"source": "### Train with PyTorch-accelerated"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "At this point, we would usually start writing the training loop; however, as we are using pytorch-accelerated, this will largely be taken care of for us. As pytorch-accelerated tracks only the training and validation losses by default, we will need a callback to track our other metrics. First, let's import the components that we need."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "from functools import partial\n\nfrom pytorch_accelerated import Trainer, notebook_launcher \nfrom pytorch_accelerated.trainer import TrainerPlaceholderValues, DEFAULT_CALLBACKS\nfrom pytorch_accelerated.callbacks import EarlyStoppingCallback, SaveBestModelCallback, TrainerCallback, StopTrainingError\nimport torchmetrics",
"execution_count": 25,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
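"source": "Now we can define the callback itself. It uses torchmetrics to accumulate the mean squared error (MSE) and mean absolute error (MAE) over each evaluation epoch, and records these, along with the root mean squared error (RMSE), in the trainer's run history."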
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import math\n\n\nclass RecommenderMetricsCallback(TrainerCallback):\n    def __init__(self):\n        self.metrics = torchmetrics.MetricCollection(\n            {\n                \"mse\": torchmetrics.MeanSquaredError(),\n                \"mae\": torchmetrics.MeanAbsoluteError(),\n            }\n        )\n\n    def _move_to_device(self, trainer):\n        self.metrics.to(trainer.device)\n\n    def on_training_run_start(self, trainer, **kwargs):\n        self._move_to_device(trainer)\n\n    def on_evaluation_run_start(self, trainer, **kwargs):\n        self._move_to_device(trainer)\n\n    def on_eval_step_end(self, trainer, batch, batch_output, **kwargs):\n        # Accumulate this step's predictions against the target ratings in batch[1]\n        preds = batch_output[\"model_outputs\"]\n        self.metrics.update(preds, batch[1])\n\n    def on_eval_epoch_end(self, trainer, **kwargs):\n        # Turn the accumulated state into final metric values\n        metrics = self.metrics.compute()\n\n        mse = metrics[\"mse\"].cpu()\n        trainer.run_history.update_metric(\"mae\", metrics[\"mae\"].cpu())\n        trainer.run_history.update_metric(\"mse\", mse)\n        trainer.run_history.update_metric(\"rmse\", math.sqrt(mse))\n\n        # Clear the accumulated state, ready for the next evaluation epoch\n        self.metrics.reset()",
"execution_count": 26,
"outputs": []
},
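{
"metadata": {},
"cell_type": "markdown",
"source": "As a quick standalone check of the accumulate-then-compute pattern used in the callback, the sketch below updates a `MetricCollection` with two batches of different sizes (the predictions and targets are made up). The result is the error over all four ratings rather than the average of the two per-batch errors. This is why it makes sense for the callback to call `update` at every step and `compute` only once at the end of each evaluation epoch.\n\n```python\nimport math\n\nimport torch\nimport torchmetrics\n\nmetrics = torchmetrics.MetricCollection(\n    {\n        'mse': torchmetrics.MeanSquaredError(),\n        'mae': torchmetrics.MeanAbsoluteError(),\n    }\n)\n\n# Two evaluation 'batches' of different sizes (values are made up)\nmetrics.update(torch.tensor([3.0, 4.0, 5.0]), torch.tensor([4.0, 4.0, 3.0]))\nmetrics.update(torch.tensor([1.0]), torch.tensor([2.0]))\n\nresults = metrics.compute()\nrmse = math.sqrt(results['mse'])\nprint(results['mse'], results['mae'], rmse)  # MSE 1.5, MAE 1.0, RMSE about 1.2247\nmetrics.reset()\n```\n\nAveraging the per-batch MSEs instead (5/3 and 1) would give about 1.333, which overweights the smaller batch."
},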
{
"metadata": {},
"cell_type": "markdown",
"source": "Now, all that is left to do is to train the model. PyTorch-accelerated provides a notebook_launcher function, which enables us to run multi-GPU training runs from within a notebook. To use this, all we need to do is to define a training function that instantiates our Trainer object and calls the train method.\n\nComponents such as the model and dataset can be defined anywhere in the notebook, but it is important that the trainer is only ever instantiated within a training function."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "def train_mf_model():\n    model = MfDotBias(\n        120, len(user_lookup), len(movie_lookup), ratings_range=[0.5, 5.5]\n    )\n    loss_func = torch.nn.MSELoss()\n\n    optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)\n\n    # The number of epochs and update steps per epoch are given as placeholders,\n    # which the trainer fills in when it creates the scheduler\n    create_sched_fn = partial(\n        torch.optim.lr_scheduler.OneCycleLR,\n        max_lr=0.01,\n        epochs=TrainerPlaceholderValues.NUM_EPOCHS,\n        steps_per_epoch=TrainerPlaceholderValues.NUM_UPDATE_STEPS_PER_EPOCH,\n    )\n\n    trainer = Trainer(\n        model=model,\n        loss_func=loss_func,\n        optimizer=optimizer,\n        callbacks=(\n            RecommenderMetricsCallback,\n            *DEFAULT_CALLBACKS,\n            SaveBestModelCallback(watch_metric=\"mae\"),\n            EarlyStoppingCallback(\n                early_stopping_patience=2,\n                early_stopping_threshold=0.001,\n                watch_metric=\"mae\",\n            ),\n        ),\n    )\n\n    trainer.train(\n        train_dataset=train_dataset,\n        eval_dataset=valid_dataset,\n        num_epochs=30,\n        per_device_batch_size=512,\n        create_scheduler_fn=create_sched_fn,\n    )\n",
"execution_count": 27,
"outputs": []
},
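{
"metadata": {},
"cell_type": "markdown",
"source": "`OneCycleLR` needs to know the total number of optimizer steps in advance. This is why the training function passes it the number of epochs and update steps per epoch, as `TrainerPlaceholderValues` that the trainer replaces with the actual values. The standalone sketch below uses a toy linear model and made-up sizes (3 epochs of 5 steps) to record the learning rate at each step when `max_lr=0.01`. With the default `div_factor=25`, the schedule starts at 0.01 / 25 = 0.0004, warms up towards `max_lr` over roughly the first 30% of steps (the default `pct_start=0.3`), and then anneals to a much smaller final value. Note that the scheduler overrides the `lr` passed to the optimizer.\n\n```python\nimport torch\nfrom torch import nn\n\nmodel = nn.Linear(2, 1)  # toy model; only its parameters matter here\noptimizer = torch.optim.AdamW(model.parameters(), lr=0.01)\n\nnum_epochs, steps_per_epoch = 3, 5  # illustrative values only\nscheduler = torch.optim.lr_scheduler.OneCycleLR(\n    optimizer,\n    max_lr=0.01,\n    epochs=num_epochs,\n    steps_per_epoch=steps_per_epoch,\n)\n\nlrs = []\nfor _ in range(num_epochs * steps_per_epoch):\n    optimizer.step()  # no gradients exist, so this leaves the weights unchanged\n    lrs.append(scheduler.get_last_lr()[0])\n    scheduler.step()\n\nprint([round(lr, 5) for lr in lrs])\n```"
},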
{
"metadata": {
"scrolled": true,
"trusted": false
},
"cell_type": "code",
"source": "notebook_launcher(train_mf_model, num_processes=2)",
"execution_count": 79,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Launching a training on 2 GPUs.\n\nStarting training run\n\nStarting epoch 1\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 43.97it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 6.917553340860793\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 15.46it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 2.295109510421753\n\neval_loss_epoch: 7.015153566996257\n\nrmse: 2.6486134878555867\n\nmse: 7.015153408050537\n\nStarting epoch 2\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 43.85it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 6.613832648087726\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 15.00it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 2.2753686904907227\n\neval_loss_epoch: 6.9166419506073\n\nrmse: 2.629950895394954\n\nmse: 6.916641712188721\n\nImprovement of 0.019740819931030273 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 3\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 42.62it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 6.165679518643663\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 14.32it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 2.237168550491333\n\neval_loss_epoch: 6.745323101679484\n\nrmse: 2.59717600118905\n\nmse: 6.745323181152344\n\nImprovement of 0.03820013999938965 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 4\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:23<00:00, 42.14it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 5.713309306685883\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 14.77it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 2.1757984161376953\n\neval_loss_epoch: 6.463704665501912\n\nrmse: 2.5423816759151356\n\nmse: 6.463704586029053\n\nImprovement of 0.061370134353637695 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 5\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 43.71it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 5.240608819849582\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 15.74it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 2.0645127296447754\n\neval_loss_epoch: 5.93857479095459\n\nrmse: 2.4369190208370552\n\nmse: 5.938574314117432\n\nImprovement of 0.11128568649291992 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 6\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 42.76it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 4.5496183526994765\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 15.47it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 1.8991206884384155\n\neval_loss_epoch: 5.143191337585449\n\nrmse: 2.2678606249993862\n\nmse: 5.143191814422607\n\nImprovement of 0.16539204120635986 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 7\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 42.67it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 3.7405364967644767\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 14.57it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 1.739580750465393\n\neval_loss_epoch: 4.3907707532246905\n\nrmse: 2.095416644052063\n\nmse: 4.39077091217041\n\nImprovement of 0.15953993797302246 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 8\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 42.26it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 3.0049563406915794\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 14.61it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 1.5949243307113647\n\neval_loss_epoch: 3.7915724913279214\n\nrmse: 1.9471960791868859\n\nmse: 3.7915725708007812\n\nImprovement of 0.14465641975402832 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 9\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 43.52it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 2.226169708828479\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 15.21it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 1.383549690246582\n\neval_loss_epoch: 3.02661395072937\n\nrmse: 1.7397166294340496\n\nmse: 3.02661395072937\n\nImprovement of 0.21137464046478271 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 10\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 43.40it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 1.4580235015965393\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 15.30it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 1.18148672580719\n\neval_loss_epoch: 2.3194796641667685\n\nrmse: 1.5229838160345626\n\nmse: 2.3194797039031982\n\nImprovement of 0.2020629644393921 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 11\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 43.93it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.9879166315070879\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 13.92it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 1.0579732656478882\n\neval_loss_epoch: 1.898934801419576\n\nrmse: 1.3780184040667658\n\nmse: 1.8989347219467163\n\nImprovement of 0.12351346015930176 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 12\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 42.43it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.7626882120583256\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 14.72it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.9977782368659973\n\neval_loss_epoch: 1.6841108997662861\n\nrmse: 1.2977330230472526\n\nmse: 1.6841109991073608\n\nImprovement of 0.06019502878189087 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 13\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 43.88it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.6493910940250825\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 15.11it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.9653714299201965\n\neval_loss_epoch: 1.5768212874730427\n\nrmse: 1.2557154642710555\n\nmse: 1.5768213272094727\n\nImprovement of 0.03240680694580078 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 14\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 43.17it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.5876042361608126\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 15.26it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.9496553540229797\n\neval_loss_epoch: 1.5230658650398254\n\nrmse: 1.2341255708575487\n\nmse: 1.5230659246444702\n\nImprovement of 0.015716075897216797 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 15\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 43.23it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.5365160376478052\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 15.23it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.9332759976387024\n\neval_loss_epoch: 1.4715567429860432\n\nrmse: 1.2130773526503509\n\nmse: 1.4715566635131836\n\nImprovement of 0.016379356384277344 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 16\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 42.80it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.48783980191797477\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 14.65it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.9213075637817383\n\neval_loss_epoch: 1.4255497852961223\n\nrmse: 1.193963929425417\n\nmse: 1.425549864768982\n\nImprovement of 0.011968433856964111 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 17\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:23<00:00, 42.15it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.43791735835227613\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 14.22it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.9033301472663879\n\neval_loss_epoch: 1.3603304823239644\n\nrmse: 1.1663321060765837\n\nmse: 1.360330581665039\n\nImprovement of 0.017977416515350342 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 18\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 43.40it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.38761754252638064\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 14.99it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.892224133014679\n\neval_loss_epoch: 1.3282848397890727\n\nrmse: 1.1525123602148473\n\nmse: 1.328284740447998\n\nImprovement of 0.011106014251708984 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 19\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 43.21it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.3391927664852044\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 14.48it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.8964769244194031\n\neval_loss_epoch: 1.3362776637077332\n\nrmse: 1.1559747419831838\n\nmse: 1.3362776041030884\nNo improvement above threshold observed, incrementing counter. \nEarly stopping counter: 1/2\n\nStarting epoch 20\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 42.99it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.29504243754824944\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 15.36it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.8905065655708313\n\neval_loss_epoch: 1.321727176507314\n\nrmse: 1.1496639320423596\n\nmse: 1.3217271566390991\n\nImprovement of 0.0017175674438476562 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 21\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 43.01it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.25342329616833914\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 14.97it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.8901656270027161\n\neval_loss_epoch: 1.3163430293401082\n\nrmse: 1.1473199506138374\n\nmse: 1.316343069076538\nNo improvement above threshold observed, incrementing counter. \nEarly stopping counter: 1/2\n\nStarting epoch 22\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:22<00:00, 42.68it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.2163390919612193\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 14.88it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.8915655612945557\n\neval_loss_epoch: 1.3194273908933003\n\nrmse: 1.148663297500658\n\nmse: 1.3194273710250854\nNo improvement above threshold observed, incrementing counter. \nEarly stopping counter: 2/2\nStopping training due to no improvement after 2 epochs\nFinishing training run\nLoading checkpoint with mae: 0.8901656270027161\n"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
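"source": "The best checkpoint, which was restored at the end of training, achieved a validation MAE of approximately 0.890 and an RMSE of approximately 1.147. Comparing this to our baseline, we can see that there is an improvement!"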
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Sequential recommendations using a transformer"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Using matrix factorization, we treat each rating as independent of the ratings around it; however, incorporating information about other movies that a user has recently rated could provide an additional signal to boost performance. For example, suppose that a user is watching a trilogy of films; if they have rated the first two instalments highly, it is likely that they will do the same for the finale!\n\nOne way to approach this is to use a transformer network, specifically the encoder portion, to encode additional context into the learned embeddings for each movie, and then use a fully connected neural network to make the rating predictions."
},
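{
"metadata": {},
"cell_type": "markdown",
"source": "As a rough illustration of this idea only (not necessarily the architecture used later in this post), the sketch below works on a batch of hypothetical, left-padded movie-ID sequences. It embeds them and passes them through PyTorch's `nn.TransformerEncoder`, so that each position can attend to the other movies in its sequence. The contextualized embedding of the final movie is then fed into a small fully connected head. All sizes are arbitrary, and `batch_first=True` requires PyTorch 1.9 or later.\n\n```python\nimport torch\nfrom torch import nn\n\ntorch.manual_seed(0)\n\n# Arbitrary sizes, for illustration only\nn_movies, d_model, seq_len, batch_size = 100, 32, 10, 4\n\nmovie_embedding = nn.Embedding(n_movies + 1, d_model, padding_idx=0)  # 0 = padding\nencoder = nn.TransformerEncoder(\n    nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True),\n    num_layers=2,\n)\nhead = nn.Sequential(nn.Linear(d_model, 64), nn.ReLU(), nn.Linear(64, 1))\n\n# Hypothetical movie IDs; the first 3 positions of every sequence are padding\nmovie_ids = torch.randint(1, n_movies + 1, (batch_size, seq_len))\nmovie_ids[:, :3] = 0\n\n# True marks positions that attention should ignore\npadding_mask = movie_ids == 0\ncontext = encoder(movie_embedding(movie_ids), src_key_padding_mask=padding_mask)\n\n# Predict a rating for the last (most recent) movie in each sequence\npredictions = head(context[:, -1]).squeeze(-1)\nprint(context.shape, predictions.shape)  # torch.Size([4, 10, 32]) torch.Size([4])\n```"
},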
{
"metadata": {},
"cell_type": "markdown",
"source": "### Pre-processing the data"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "The first step is to process our data so that we have a time-sorted list of movies for each user. Let's start by sorting all the ratings by timestamp and then grouping them by user:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "grouped_ratings = ratings_df.sort_values(by='unix_timestamp').groupby('user_id').agg(tuple).reset_index()",
"execution_count": 28,
"outputs": []
},
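{
"metadata": {},
"cell_type": "markdown",
"source": "To see what this chain of operations does, here it is applied to a tiny, made-up ratings table (not part of the MovieLens data). Because the rows are sorted before grouping, and `groupby` preserves the order of rows within each group, each user's tuples come out in time order. The sketch also counts each user's ratings, as we do next.\n\n```python\nimport pandas as pd\n\n# A made-up ratings table, deliberately not in time order\ntoy_ratings = pd.DataFrame(\n    {\n        'user_id': [2, 1, 1, 2, 1],\n        'title': ['C', 'B', 'A', 'D', 'E'],\n        'rating': [3, 4, 5, 2, 1],\n        'unix_timestamp': [30, 20, 10, 40, 50],\n    }\n)\n\ntoy_grouped = (\n    toy_ratings.sort_values(by='unix_timestamp')\n    .groupby('user_id')\n    .agg(tuple)\n    .reset_index()\n)\ntoy_grouped['num_ratings'] = toy_grouped['rating'].apply(len)\nprint(toy_grouped)\n```\n\nHere, user 1 ends up with the titles `('A', 'B', 'E')` and user 2 with `('C', 'D')`, each in timestamp order."
},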
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "grouped_ratings",
"execution_count": 29,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>user_id</th>\n <th>title</th>\n <th>rating</th>\n <th>unix_timestamp</th>\n <th>is_valid</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>1</td>\n <td>(Girl, Interrupted (1999), Cinderella (1950), ...</td>\n <td>(4, 5, 4, 5, 3, 5, 4, 4, 5, 4, 5, 3, 4, 4, 4, ...</td>\n <td>(978300019, 978300055, 978300055, 978300055, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n </tr>\n <tr>\n <th>1</th>\n <td>10</td>\n <td>(Godfather, The (1972), Pretty Woman (1990), S...</td>\n <td>(3, 4, 3, 4, 4, 3, 5, 5, 5, 3, 3, 4, 5, 4, 4, ...</td>\n <td>(978224375, 978224375, 978224375, 978224400, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n </tr>\n <tr>\n <th>2</th>\n <td>100</td>\n <td>(Starship Troopers (1997), Star Wars: Episode ...</td>\n <td>(3, 4, 4, 3, 4, 3, 1, 1, 5, 4, 4, 3, 4, 2, 3, ...</td>\n <td>(977593595, 977593595, 977593607, 977593624, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n </tr>\n <tr>\n <th>3</th>\n <td>1000</td>\n <td>(Cat on a Hot Tin Roof (1958), Licence to Kill...</td>\n <td>(4, 4, 5, 3, 5, 5, 2, 5, 4, 4, 5, 3, 5, 5, 5, ...</td>\n <td>(975040566, 975040566, 975040566, 975040629, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n </tr>\n <tr>\n <th>4</th>\n <td>1001</td>\n <td>(Raiders of the Lost Ark (1981), Guinevere (19...</td>\n <td>(4, 4, 4, 2, 2, 1, 5, 4, 5, 4, 4, 4, 4, 3, 4, ...</td>\n <td>(975039591, 975039702, 975039702, 975039898, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n </tr>\n <tr>\n <th>...</th>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n </tr>\n <tr>\n 
<th>6035</th>\n <td>995</td>\n <td>(Six Days Seven Nights (1998), Star Wars: Epis...</td>\n <td>(2, 4, 5, 4, 3, 3, 4, 4, 3, 5, 5, 5, 5, 5, 5, ...</td>\n <td>(975054785, 975054785, 975054785, 975054853, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n </tr>\n <tr>\n <th>6036</th>\n <td>996</td>\n <td>(Nightmare on Elm Street, A (1984), St. Elmo's...</td>\n <td>(4, 3, 5, 3, 5, 5, 5, 5, 4, 2, 5, 5, 5, 4, 5, ...</td>\n <td>(975052132, 975052132, 975052195, 975052284, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n </tr>\n <tr>\n <th>6037</th>\n <td>997</td>\n <td>(Star Wars: Episode V - The Empire Strikes Bac...</td>\n <td>(4, 3, 3, 3, 2, 5, 5, 5, 4, 4, 5, 4, 4, 3, 4, ...</td>\n <td>(975044235, 975044425, 975044426, 975044426, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n </tr>\n <tr>\n <th>6038</th>\n <td>998</td>\n <td>(Butcher's Wife, The (1991), E.T. the Extra-Te...</td>\n <td>(3, 5, 4, 5, 3, 4, 4, 3, 4, 4, 4, 4, 4, 5, 4, ...</td>\n <td>(975043499, 975043593, 975043593, 975043593, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n </tr>\n <tr>\n <th>6039</th>\n <td>999</td>\n <td>(Star Wars: Episode V - The Empire Strikes Bac...</td>\n <td>(5, 3, 1, 2, 4, 4, 5, 5, 4, 5, 5, 4, 4, 5, 4, ...</td>\n <td>(975042787, 975042921, 975043058, 975043058, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n </tr>\n </tbody>\n</table>\n<p>6040 rows × 5 columns</p>\n</div>",
"text/plain": " user_id title \\\n0 1 (Girl, Interrupted (1999), Cinderella (1950), ... \n1 10 (Godfather, The (1972), Pretty Woman (1990), S... \n2 100 (Starship Troopers (1997), Star Wars: Episode ... \n3 1000 (Cat on a Hot Tin Roof (1958), Licence to Kill... \n4 1001 (Raiders of the Lost Ark (1981), Guinevere (19... \n... ... ... \n6035 995 (Six Days Seven Nights (1998), Star Wars: Epis... \n6036 996 (Nightmare on Elm Street, A (1984), St. Elmo's... \n6037 997 (Star Wars: Episode V - The Empire Strikes Bac... \n6038 998 (Butcher's Wife, The (1991), E.T. the Extra-Te... \n6039 999 (Star Wars: Episode V - The Empire Strikes Bac... \n\n rating \\\n0 (4, 5, 4, 5, 3, 5, 4, 4, 5, 4, 5, 3, 4, 4, 4, ... \n1 (3, 4, 3, 4, 4, 3, 5, 5, 5, 3, 3, 4, 5, 4, 4, ... \n2 (3, 4, 4, 3, 4, 3, 1, 1, 5, 4, 4, 3, 4, 2, 3, ... \n3 (4, 4, 5, 3, 5, 5, 2, 5, 4, 4, 5, 3, 5, 5, 5, ... \n4 (4, 4, 4, 2, 2, 1, 5, 4, 5, 4, 4, 4, 4, 3, 4, ... \n... ... \n6035 (2, 4, 5, 4, 3, 3, 4, 4, 3, 5, 5, 5, 5, 5, 5, ... \n6036 (4, 3, 5, 3, 5, 5, 5, 5, 4, 2, 5, 5, 5, 4, 5, ... \n6037 (4, 3, 3, 3, 2, 5, 5, 5, 4, 4, 5, 4, 4, 3, 4, ... \n6038 (3, 5, 4, 5, 3, 4, 4, 3, 4, 4, 4, 4, 4, 5, 4, ... \n6039 (5, 3, 1, 2, 4, 4, 5, 5, 4, 5, 5, 4, 4, 5, 4, ... \n\n unix_timestamp \\\n0 (978300019, 978300055, 978300055, 978300055, 9... \n1 (978224375, 978224375, 978224375, 978224400, 9... \n2 (977593595, 977593595, 977593607, 977593624, 9... \n3 (975040566, 975040566, 975040566, 975040629, 9... \n4 (975039591, 975039702, 975039702, 975039898, 9... \n... ... \n6035 (975054785, 975054785, 975054785, 975054853, 9... \n6036 (975052132, 975052132, 975052195, 975052284, 9... \n6037 (975044235, 975044425, 975044426, 975044426, 9... \n6038 (975043499, 975043593, 975043593, 975043593, 9... \n6039 (975042787, 975042921, 975043058, 975043058, 9... \n\n is_valid \n0 (False, False, False, False, False, False, Fal... \n1 (False, False, False, False, False, False, Fal... \n2 (False, False, False, False, False, False, Fal... 
\n3 (False, False, False, False, False, False, Fal... \n4 (False, False, False, False, False, False, Fal... \n... ... \n6035 (False, False, False, False, False, False, Fal... \n6036 (False, False, False, False, False, False, Fal... \n6037 (False, False, False, False, False, False, Fal... \n6038 (False, False, False, False, False, False, Fal... \n6039 (False, False, False, False, False, False, Fal... \n\n[6040 rows x 5 columns]"
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Now that we have grouped by user, we can create an additional column recording the number of ratings associated with each user:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "grouped_ratings['num_ratings'] = grouped_ratings['rating'].apply(len)",
"execution_count": 30,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Let's take a look at the new dataframe:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "grouped_ratings",
"execution_count": 31,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>user_id</th>\n <th>title</th>\n <th>rating</th>\n <th>unix_timestamp</th>\n <th>is_valid</th>\n <th>num_ratings</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>1</td>\n <td>(Girl, Interrupted (1999), Cinderella (1950), ...</td>\n <td>(4, 5, 4, 5, 3, 5, 4, 4, 5, 4, 5, 3, 4, 4, 4, ...</td>\n <td>(978300019, 978300055, 978300055, 978300055, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n <td>53</td>\n </tr>\n <tr>\n <th>1</th>\n <td>10</td>\n <td>(Godfather, The (1972), Pretty Woman (1990), S...</td>\n <td>(3, 4, 3, 4, 4, 3, 5, 5, 5, 3, 3, 4, 5, 4, 4, ...</td>\n <td>(978224375, 978224375, 978224375, 978224400, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n <td>401</td>\n </tr>\n <tr>\n <th>2</th>\n <td>100</td>\n <td>(Starship Troopers (1997), Star Wars: Episode ...</td>\n <td>(3, 4, 4, 3, 4, 3, 1, 1, 5, 4, 4, 3, 4, 2, 3, ...</td>\n <td>(977593595, 977593595, 977593607, 977593624, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n <td>76</td>\n </tr>\n <tr>\n <th>3</th>\n <td>1000</td>\n <td>(Cat on a Hot Tin Roof (1958), Licence to Kill...</td>\n <td>(4, 4, 5, 3, 5, 5, 2, 5, 4, 4, 5, 3, 5, 5, 5, ...</td>\n <td>(975040566, 975040566, 975040566, 975040629, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n <td>84</td>\n </tr>\n <tr>\n <th>4</th>\n <td>1001</td>\n <td>(Raiders of the Lost Ark (1981), Guinevere (19...</td>\n <td>(4, 4, 4, 2, 2, 1, 5, 4, 5, 4, 4, 4, 4, 3, 4, ...</td>\n <td>(975039591, 975039702, 975039702, 975039898, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n <td>377</td>\n </tr>\n <tr>\n 
<th>...</th>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n </tr>\n <tr>\n <th>6035</th>\n <td>995</td>\n <td>(Six Days Seven Nights (1998), Star Wars: Epis...</td>\n <td>(2, 4, 5, 4, 3, 3, 4, 4, 3, 5, 5, 5, 5, 5, 5, ...</td>\n <td>(975054785, 975054785, 975054785, 975054853, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n <td>49</td>\n </tr>\n <tr>\n <th>6036</th>\n <td>996</td>\n <td>(Nightmare on Elm Street, A (1984), St. Elmo's...</td>\n <td>(4, 3, 5, 3, 5, 5, 5, 5, 4, 2, 5, 5, 5, 4, 5, ...</td>\n <td>(975052132, 975052132, 975052195, 975052284, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n <td>296</td>\n </tr>\n <tr>\n <th>6037</th>\n <td>997</td>\n <td>(Star Wars: Episode V - The Empire Strikes Bac...</td>\n <td>(4, 3, 3, 3, 2, 5, 5, 5, 4, 4, 5, 4, 4, 3, 4, ...</td>\n <td>(975044235, 975044425, 975044426, 975044426, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n <td>30</td>\n </tr>\n <tr>\n <th>6038</th>\n <td>998</td>\n <td>(Butcher's Wife, The (1991), E.T. the Extra-Te...</td>\n <td>(3, 5, 4, 5, 3, 4, 4, 3, 4, 4, 4, 4, 4, 5, 4, ...</td>\n <td>(975043499, 975043593, 975043593, 975043593, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n <td>135</td>\n </tr>\n <tr>\n <th>6039</th>\n <td>999</td>\n <td>(Star Wars: Episode V - The Empire Strikes Bac...</td>\n <td>(5, 3, 1, 2, 4, 4, 5, 5, 4, 5, 5, 4, 4, 5, 4, ...</td>\n <td>(975042787, 975042921, 975043058, 975043058, 9...</td>\n <td>(False, False, False, False, False, False, Fal...</td>\n <td>412</td>\n </tr>\n </tbody>\n</table>\n<p>6040 rows × 6 columns</p>\n</div>",
"text/plain": " user_id title \\\n0 1 (Girl, Interrupted (1999), Cinderella (1950), ... \n1 10 (Godfather, The (1972), Pretty Woman (1990), S... \n2 100 (Starship Troopers (1997), Star Wars: Episode ... \n3 1000 (Cat on a Hot Tin Roof (1958), Licence to Kill... \n4 1001 (Raiders of the Lost Ark (1981), Guinevere (19... \n... ... ... \n6035 995 (Six Days Seven Nights (1998), Star Wars: Epis... \n6036 996 (Nightmare on Elm Street, A (1984), St. Elmo's... \n6037 997 (Star Wars: Episode V - The Empire Strikes Bac... \n6038 998 (Butcher's Wife, The (1991), E.T. the Extra-Te... \n6039 999 (Star Wars: Episode V - The Empire Strikes Bac... \n\n rating \\\n0 (4, 5, 4, 5, 3, 5, 4, 4, 5, 4, 5, 3, 4, 4, 4, ... \n1 (3, 4, 3, 4, 4, 3, 5, 5, 5, 3, 3, 4, 5, 4, 4, ... \n2 (3, 4, 4, 3, 4, 3, 1, 1, 5, 4, 4, 3, 4, 2, 3, ... \n3 (4, 4, 5, 3, 5, 5, 2, 5, 4, 4, 5, 3, 5, 5, 5, ... \n4 (4, 4, 4, 2, 2, 1, 5, 4, 5, 4, 4, 4, 4, 3, 4, ... \n... ... \n6035 (2, 4, 5, 4, 3, 3, 4, 4, 3, 5, 5, 5, 5, 5, 5, ... \n6036 (4, 3, 5, 3, 5, 5, 5, 5, 4, 2, 5, 5, 5, 4, 5, ... \n6037 (4, 3, 3, 3, 2, 5, 5, 5, 4, 4, 5, 4, 4, 3, 4, ... \n6038 (3, 5, 4, 5, 3, 4, 4, 3, 4, 4, 4, 4, 4, 5, 4, ... \n6039 (5, 3, 1, 2, 4, 4, 5, 5, 4, 5, 5, 4, 4, 5, 4, ... \n\n unix_timestamp \\\n0 (978300019, 978300055, 978300055, 978300055, 9... \n1 (978224375, 978224375, 978224375, 978224400, 9... \n2 (977593595, 977593595, 977593607, 977593624, 9... \n3 (975040566, 975040566, 975040566, 975040629, 9... \n4 (975039591, 975039702, 975039702, 975039898, 9... \n... ... \n6035 (975054785, 975054785, 975054785, 975054853, 9... \n6036 (975052132, 975052132, 975052195, 975052284, 9... \n6037 (975044235, 975044425, 975044426, 975044426, 9... \n6038 (975043499, 975043593, 975043593, 975043593, 9... \n6039 (975042787, 975042921, 975043058, 975043058, 9... \n\n is_valid num_ratings \n0 (False, False, False, False, False, False, Fal... 53 \n1 (False, False, False, False, False, False, Fal... 
401 \n2 (False, False, False, False, False, False, Fal... 76 \n3 (False, False, False, False, False, False, Fal... 84 \n4 (False, False, False, False, False, False, Fal... 377 \n... ... ... \n6035 (False, False, False, False, False, False, Fal... 49 \n6036 (False, False, False, False, False, False, Fal... 296 \n6037 (False, False, False, False, False, False, Fal... 30 \n6038 (False, False, False, False, False, False, Fal... 135 \n6039 (False, False, False, False, False, False, Fal... 412 \n\n[6040 rows x 6 columns]"
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Now that we have grouped all the ratings for each user, let's divide these into smaller sequences. To make the most of the data, we would like the model to have the opportunity to predict a rating for every movie in the training set. To do this, let's specify a sequence length s and, for each rating, use the previous s-1 ratings as the user's history, so that every movie that a user has rated appears as the final element of exactly one sequence.\n\nAs the model expects each sequence to have a fixed length, we will left-pad any sequence shorter than s with a padding token, so that sequences can be batched together and passed to the model. Let's create a function to do this.\n\nHere, we arbitrarily choose a sequence length of 10."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "sequence_length = 10",
"execution_count": 32,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "def create_sequences(values, sequence_length):\n    sequences = []\n    for i in range(len(values)):\n        # All values up to and including position i\n        seq = values[:i+1]\n        if len(seq) > sequence_length:\n            # Keep only the most recent sequence_length values\n            seq = seq[i-sequence_length+1:i+1]\n        elif len(seq) < sequence_length:\n            # Left-pad short sequences so that every sequence has the same length\n            seq = (*(['[PAD]'] * (sequence_length - len(seq))), *seq)\n\n        sequences.append(seq)\n    return sequences\n",
"execution_count": 33,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "To visualize how this function works, let's apply it, with a sequence length of 3, to the first 10 movies rated by the first user. These movies are:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "grouped_ratings.iloc[0]['title'][:10]",
"execution_count": 34,
"outputs": [
{
"data": {
"text/plain": "('Girl, Interrupted (1999)',\n 'Cinderella (1950)',\n 'Titanic (1997)',\n 'Back to the Future (1985)',\n 'Meet Joe Black (1998)',\n 'Last Days of Disco, The (1998)',\n 'Erin Brockovich (2000)',\n 'To Kill a Mockingbird (1962)',\n 'Christmas Story, A (1983)',\n 'Star Wars: Episode IV - A New Hope (1977)')"
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Applying our function, we have:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "create_sequences(grouped_ratings.iloc[0]['title'][:10], 3)",
"execution_count": 35,
"outputs": [
{
"data": {
"text/plain": "[('[PAD]', '[PAD]', 'Girl, Interrupted (1999)'),\n ('[PAD]', 'Girl, Interrupted (1999)', 'Cinderella (1950)'),\n ('Girl, Interrupted (1999)', 'Cinderella (1950)', 'Titanic (1997)'),\n ('Cinderella (1950)', 'Titanic (1997)', 'Back to the Future (1985)'),\n ('Titanic (1997)', 'Back to the Future (1985)', 'Meet Joe Black (1998)'),\n ('Back to the Future (1985)',\n 'Meet Joe Black (1998)',\n 'Last Days of Disco, The (1998)'),\n ('Meet Joe Black (1998)',\n 'Last Days of Disco, The (1998)',\n 'Erin Brockovich (2000)'),\n ('Last Days of Disco, The (1998)',\n 'Erin Brockovich (2000)',\n 'To Kill a Mockingbird (1962)'),\n ('Erin Brockovich (2000)',\n 'To Kill a Mockingbird (1962)',\n 'Christmas Story, A (1983)'),\n ('To Kill a Mockingbird (1962)',\n 'Christmas Story, A (1983)',\n 'Star Wars: Episode IV - A New Hope (1977)')]"
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
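"source": "As we can see, we have 10 sequences of length 3, one for each movie. The final movie in each sequence is the corresponding movie from the original list, preceded by up to two of the movies rated before it; where fewer earlier movies are available, the sequence is left-padded with `[PAD]`."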
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Now, let's apply this function to each of the sequence columns in our DataFrame:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "grouped_cols = ['title', 'rating', 'unix_timestamp', 'is_valid']\nfor col in grouped_cols:\n    grouped_ratings[col] = grouped_ratings[col].apply(lambda x: create_sequences(x, sequence_length))",
"execution_count": 36,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "grouped_ratings.head(2)",
"execution_count": 37,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>user_id</th>\n <th>title</th>\n <th>rating</th>\n <th>unix_timestamp</th>\n <th>is_valid</th>\n <th>num_ratings</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>1</td>\n <td>[([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [P...</td>\n <td>[([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [P...</td>\n <td>[([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [P...</td>\n <td>[([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [P...</td>\n <td>53</td>\n </tr>\n <tr>\n <th>1</th>\n <td>10</td>\n <td>[([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [P...</td>\n <td>[([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [P...</td>\n <td>[([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [P...</td>\n <td>[([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [P...</td>\n <td>401</td>\n </tr>\n </tbody>\n</table>\n</div>",
"text/plain": " user_id title \\\n0 1 [([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [P... \n1 10 [([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [P... \n\n rating \\\n0 [([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [P... \n1 [([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [P... \n\n unix_timestamp \\\n0 [([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [P... \n1 [([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [P... \n\n is_valid num_ratings \n0 [([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [P... 53 \n1 [([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [P... 401 "
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Currently, we have one row that contains all the sequences for a given user. However, during training, we would like to create batches made up of sequences from many different users. To do this, we will have to transform the data so that each sequence has its own row, while remaining associated with the user ID. We can use the pandas `explode` method on each column and then concatenate the resulting DataFrames side by side. The rows line up correctly because every column contains the same number of sequences for each user."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "exploded_ratings = grouped_ratings[['user_id', 'title']].explode('title', ignore_index=True)\ndfs = [grouped_ratings[[col]].explode(col, ignore_index=True) for col in grouped_cols[1:]]\nseq_df = pd.concat([exploded_ratings, *dfs], axis=1)",
"execution_count": 38,
"outputs": []
},
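{
"metadata": {},
"cell_type": "markdown",
"source": "To sanity-check this explode-then-concatenate step, here is a minimal sketch on a small, hypothetical `toy` DataFrame with made-up titles and ratings, not the real data. Passing `ignore_index=True` gives each exploded result a fresh `0..n-1` index, so `pd.concat(..., axis=1)` pairs rows by position. Assuming pandas 1.3 or later, `DataFrame.explode` also accepts a list of columns with matching element counts, which should produce the same rows in a single call."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import pandas as pd\n\n# Hypothetical toy data: one row per user, each cell a list of sequences\ntoy = pd.DataFrame({\n    'user_id': ['1', '2'],\n    'title': [[('a', 'b'), ('b', 'c')], [('x', 'y')]],\n    'rating': [[(4, 5), (5, 3)], [(2, 1)]],\n})\n\n# Explode each column separately; ignore_index gives every result a fresh 0..n-1 index\nexploded_titles = toy[['user_id', 'title']].explode('title', ignore_index=True)\nexploded_ratings = toy[['rating']].explode('rating', ignore_index=True)\n\n# Rows line up by position because each user has the same number of sequences in every column\ntoy_seq_df = pd.concat([exploded_titles, exploded_ratings], axis=1)\nprint(toy_seq_df)\n\n# Assumes pandas >= 1.3: explode several list columns with matching lengths in one call\nmulti = toy.explode(['title', 'rating'], ignore_index=True)\nprint(multi.values.tolist() == toy_seq_df.values.tolist())",
"execution_count": null,
"outputs": []
},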
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "seq_df.head()",
"execution_count": 39,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>user_id</th>\n <th>title</th>\n <th>rating</th>\n <th>unix_timestamp</th>\n <th>is_valid</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>1</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n </tr>\n <tr>\n <th>1</th>\n <td>1</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n </tr>\n <tr>\n <th>2</th>\n <td>1</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n </tr>\n <tr>\n <th>3</th>\n <td>1</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], Gir...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], 4, ...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], 978...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], Fal...</td>\n </tr>\n <tr>\n <th>4</th>\n <td>1</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], Girl, Inte...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], 4, 5, 4, 5...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], 978300019,...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], False, Fal...</td>\n </tr>\n </tbody>\n</table>\n</div>",
"text/plain": " user_id title \\\n0 1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n1 1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n2 1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n3 1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], Gir... \n4 1 ([PAD], [PAD], [PAD], [PAD], [PAD], Girl, Inte... \n\n rating \\\n0 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n2 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n3 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], 4, ... \n4 ([PAD], [PAD], [PAD], [PAD], [PAD], 4, 5, 4, 5... \n\n unix_timestamp \\\n0 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n2 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n3 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], 978... \n4 ([PAD], [PAD], [PAD], [PAD], [PAD], 978300019,... \n\n is_valid \n0 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n2 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n3 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], Fal... \n4 ([PAD], [PAD], [PAD], [PAD], [PAD], False, Fal... "
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Now we can see that each sequence has its own row. However, for the `is_valid` column, we don't need the whole sequence. We only need the last value, because it corresponds to the movie whose rating we will be trying to predict. Let's create a function to extract this value and apply it to this column."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "def get_last_entry(sequence):\n    return sequence[-1]\n\nseq_df['is_valid'] = seq_df['is_valid'].apply(get_last_entry)",
"execution_count": 40,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "seq_df",
"execution_count": 41,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>user_id</th>\n <th>title</th>\n <th>rating</th>\n <th>unix_timestamp</th>\n <th>is_valid</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>1</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>False</td>\n </tr>\n <tr>\n <th>1</th>\n <td>1</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>False</td>\n </tr>\n <tr>\n <th>2</th>\n <td>1</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>False</td>\n </tr>\n <tr>\n <th>3</th>\n <td>1</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], Gir...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], 4, ...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], 978...</td>\n <td>False</td>\n </tr>\n <tr>\n <th>4</th>\n <td>1</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], Girl, Inte...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], 4, 5, 4, 5...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], 978300019,...</td>\n <td>False</td>\n </tr>\n <tr>\n <th>...</th>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n </tr>\n <tr>\n <th>1000204</th>\n <td>999</td>\n <td>(General's Daughter, The (1999), Powder (1995)...</td>\n <td>(3, 3, 2, 1, 3, 2, 3, 2, 4, 3)</td>\n <td>(975364681, 975364717, 975364717, 975364717, 9...</td>\n <td>False</td>\n </tr>\n 
<tr>\n <th>1000205</th>\n <td>999</td>\n <td>(Powder (1995), We're No Angels (1989), Out of...</td>\n <td>(3, 2, 1, 3, 2, 3, 2, 4, 3, 3)</td>\n <td>(975364717, 975364717, 975364717, 975364743, 9...</td>\n <td>False</td>\n </tr>\n <tr>\n <th>1000206</th>\n <td>999</td>\n <td>(We're No Angels (1989), Out of Africa (1985),...</td>\n <td>(2, 1, 3, 2, 3, 2, 4, 3, 3, 3)</td>\n <td>(975364717, 975364717, 975364743, 975364743, 9...</td>\n <td>False</td>\n </tr>\n <tr>\n <th>1000207</th>\n <td>999</td>\n <td>(Out of Africa (1985), Instinct (1999), Corrup...</td>\n <td>(1, 3, 2, 3, 2, 4, 3, 3, 3, 2)</td>\n <td>(975364717, 975364743, 975364743, 975364784, 9...</td>\n <td>False</td>\n </tr>\n <tr>\n <th>1000208</th>\n <td>999</td>\n <td>(Instinct (1999), Corruptor, The (1999), Jack ...</td>\n <td>(3, 2, 3, 2, 4, 3, 3, 3, 2, 2)</td>\n <td>(975364743, 975364743, 975364784, 975364784, 9...</td>\n <td>True</td>\n </tr>\n </tbody>\n</table>\n<p>1000209 rows × 5 columns</p>\n</div>",
"text/plain": " user_id title \\\n0 1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n1 1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n2 1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n3 1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], Gir... \n4 1 ([PAD], [PAD], [PAD], [PAD], [PAD], Girl, Inte... \n... ... ... \n1000204 999 (General's Daughter, The (1999), Powder (1995)... \n1000205 999 (Powder (1995), We're No Angels (1989), Out of... \n1000206 999 (We're No Angels (1989), Out of Africa (1985),... \n1000207 999 (Out of Africa (1985), Instinct (1999), Corrup... \n1000208 999 (Instinct (1999), Corruptor, The (1999), Jack ... \n\n rating \\\n0 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n2 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n3 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], 4, ... \n4 ([PAD], [PAD], [PAD], [PAD], [PAD], 4, 5, 4, 5... \n... ... \n1000204 (3, 3, 2, 1, 3, 2, 3, 2, 4, 3) \n1000205 (3, 2, 1, 3, 2, 3, 2, 4, 3, 3) \n1000206 (2, 1, 3, 2, 3, 2, 4, 3, 3, 3) \n1000207 (1, 3, 2, 3, 2, 4, 3, 3, 3, 2) \n1000208 (3, 2, 3, 2, 4, 3, 3, 3, 2, 2) \n\n unix_timestamp is_valid \n0 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... False \n1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... False \n2 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... False \n3 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], 978... False \n4 ([PAD], [PAD], [PAD], [PAD], [PAD], 978300019,... False \n... ... ... \n1000204 (975364681, 975364717, 975364717, 975364717, 9... False \n1000205 (975364717, 975364717, 975364717, 975364743, 9... False \n1000206 (975364717, 975364717, 975364743, 975364743, 9... False \n1000207 (975364717, 975364743, 975364743, 975364784, 9... False \n1000208 (975364743, 975364743, 975364784, 975364784, 9... True \n\n[1000209 rows x 5 columns]"
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Also, to make it easy to access the rating that we are trying to predict, let's move it into its own column, `target_rating`, and keep the preceding ratings in `previous_ratings` as the user's rating history."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "seq_df['target_rating'] = seq_df['rating'].apply(get_last_entry)\nseq_df['previous_ratings'] = seq_df['rating'].apply(lambda seq: seq[:-1])\nseq_df.drop(columns=['rating'], inplace=True)",
"execution_count": 42,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "To prevent the model from attending to padding tokens when calculating attention scores, we can provide an attention mask to the transformer. The mask should be `True` for a padding token and `False` otherwise. Let's calculate this for each row, and also create a column recording the number of padding tokens present."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "seq_df['pad_mask'] = seq_df['title'].apply(lambda x: (np.array(x) == '[PAD]'))\nseq_df['num_pads'] = seq_df['pad_mask'].apply(sum)\nseq_df['pad_mask'] = seq_df['pad_mask'].apply(lambda x: x.tolist()) # in case we serialize later",
"execution_count": 43,
"outputs": []
},
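{
"metadata": {},
"cell_type": "markdown",
"source": "To illustrate what `True` means in a key padding mask, here is a minimal, self-contained sketch that calls PyTorch's `nn.MultiheadAttention` directly on random toy inputs. The sizes are arbitrary and unrelated to our model. `nn.TransformerEncoderLayer` passes `src_key_padding_mask` to its self-attention as `key_padding_mask`, so the same convention applies there. Padded key positions receive zero attention weight, and each row of weights still sums to one over the remaining positions."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import torch\nfrom torch import nn\n\ntorch.manual_seed(0)\n\n# Arbitrary toy sizes, chosen only for illustration\nembedding_size, num_heads, seq_len = 8, 2, 4\nattention = nn.MultiheadAttention(embedding_size, num_heads, batch_first=True)\nattention.eval()  # make sure no dropout is applied\n\nx = torch.randn(1, seq_len, embedding_size)\n# Same convention as the pad_mask column: True marks a padding position to ignore\npad_mask = torch.tensor([[True, True, False, False]])\n\nwith torch.no_grad():\n    output, weights = attention(x, x, x, key_padding_mask=pad_mask)\n\n# weights has shape (batch, query_len, key_len), averaged over heads;\n# the first two columns (the padded keys) are zero\nprint(weights[0])",
"execution_count": null,
"outputs": []
},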
{
"metadata": {},
"cell_type": "markdown",
"source": "Let's inspect the transformed data:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "seq_df",
"execution_count": 44,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>user_id</th>\n <th>title</th>\n <th>unix_timestamp</th>\n <th>is_valid</th>\n <th>target_rating</th>\n <th>previous_ratings</th>\n <th>pad_mask</th>\n <th>num_pads</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>1</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>False</td>\n <td>4</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>[True, True, True, True, True, True, True, Tru...</td>\n <td>9</td>\n </tr>\n <tr>\n <th>1</th>\n <td>1</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>False</td>\n <td>5</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>[True, True, True, True, True, True, True, Tru...</td>\n <td>8</td>\n </tr>\n <tr>\n <th>2</th>\n <td>1</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>False</td>\n <td>4</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA...</td>\n <td>[True, True, True, True, True, True, True, Fal...</td>\n <td>7</td>\n </tr>\n <tr>\n <th>3</th>\n <td>1</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], Gir...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], 978...</td>\n <td>False</td>\n <td>5</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], 4, ...</td>\n <td>[True, True, True, True, True, True, False, Fa...</td>\n <td>6</td>\n </tr>\n <tr>\n <th>4</th>\n <td>1</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], Girl, Inte...</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], 978300019,...</td>\n 
<td>False</td>\n <td>3</td>\n <td>([PAD], [PAD], [PAD], [PAD], [PAD], 4, 5, 4, 5)</td>\n <td>[True, True, True, True, True, False, False, F...</td>\n <td>5</td>\n </tr>\n <tr>\n <th>...</th>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n </tr>\n <tr>\n <th>1000204</th>\n <td>999</td>\n <td>(General's Daughter, The (1999), Powder (1995)...</td>\n <td>(975364681, 975364717, 975364717, 975364717, 9...</td>\n <td>False</td>\n <td>3</td>\n <td>(3, 3, 2, 1, 3, 2, 3, 2, 4)</td>\n <td>[False, False, False, False, False, False, Fal...</td>\n <td>0</td>\n </tr>\n <tr>\n <th>1000205</th>\n <td>999</td>\n <td>(Powder (1995), We're No Angels (1989), Out of...</td>\n <td>(975364717, 975364717, 975364717, 975364743, 9...</td>\n <td>False</td>\n <td>3</td>\n <td>(3, 2, 1, 3, 2, 3, 2, 4, 3)</td>\n <td>[False, False, False, False, False, False, Fal...</td>\n <td>0</td>\n </tr>\n <tr>\n <th>1000206</th>\n <td>999</td>\n <td>(We're No Angels (1989), Out of Africa (1985),...</td>\n <td>(975364717, 975364717, 975364743, 975364743, 9...</td>\n <td>False</td>\n <td>3</td>\n <td>(2, 1, 3, 2, 3, 2, 4, 3, 3)</td>\n <td>[False, False, False, False, False, False, Fal...</td>\n <td>0</td>\n </tr>\n <tr>\n <th>1000207</th>\n <td>999</td>\n <td>(Out of Africa (1985), Instinct (1999), Corrup...</td>\n <td>(975364717, 975364743, 975364743, 975364784, 9...</td>\n <td>False</td>\n <td>2</td>\n <td>(1, 3, 2, 3, 2, 4, 3, 3, 3)</td>\n <td>[False, False, False, False, False, False, Fal...</td>\n <td>0</td>\n </tr>\n <tr>\n <th>1000208</th>\n <td>999</td>\n <td>(Instinct (1999), Corruptor, The (1999), Jack ...</td>\n <td>(975364743, 975364743, 975364784, 975364784, 9...</td>\n <td>True</td>\n <td>2</td>\n <td>(3, 2, 3, 2, 4, 3, 3, 3, 2)</td>\n <td>[False, False, False, False, False, False, Fal...</td>\n <td>0</td>\n </tr>\n </tbody>\n</table>\n<p>1000209 rows × 8 columns</p>\n</div>",
"text/plain": " user_id title \\\n0 1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n1 1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n2 1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n3 1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], Gir... \n4 1 ([PAD], [PAD], [PAD], [PAD], [PAD], Girl, Inte... \n... ... ... \n1000204 999 (General's Daughter, The (1999), Powder (1995)... \n1000205 999 (Powder (1995), We're No Angels (1989), Out of... \n1000206 999 (We're No Angels (1989), Out of Africa (1985),... \n1000207 999 (Out of Africa (1985), Instinct (1999), Corrup... \n1000208 999 (Instinct (1999), Corruptor, The (1999), Jack ... \n\n unix_timestamp is_valid \\\n0 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... False \n1 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... False \n2 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... False \n3 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], 978... False \n4 ([PAD], [PAD], [PAD], [PAD], [PAD], 978300019,... False \n... ... ... \n1000204 (975364681, 975364717, 975364717, 975364717, 9... False \n1000205 (975364717, 975364717, 975364717, 975364743, 9... False \n1000206 (975364717, 975364717, 975364743, 975364743, 9... False \n1000207 (975364717, 975364743, 975364743, 975364784, 9... False \n1000208 (975364743, 975364743, 975364784, 975364784, 9... True \n\n target_rating previous_ratings \\\n0 4 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n1 5 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n2 4 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], [PA... \n3 5 ([PAD], [PAD], [PAD], [PAD], [PAD], [PAD], 4, ... \n4 3 ([PAD], [PAD], [PAD], [PAD], [PAD], 4, 5, 4, 5) \n... ... ... \n1000204 3 (3, 3, 2, 1, 3, 2, 3, 2, 4) \n1000205 3 (3, 2, 1, 3, 2, 3, 2, 4, 3) \n1000206 3 (2, 1, 3, 2, 3, 2, 4, 3, 3) \n1000207 2 (1, 3, 2, 3, 2, 4, 3, 3, 3) \n1000208 2 (3, 2, 3, 2, 4, 3, 3, 3, 2) \n\n pad_mask num_pads \n0 [True, True, True, True, True, True, True, Tru... 9 \n1 [True, True, True, True, True, True, True, Tru... 
8 \n2 [True, True, True, True, True, True, True, Fal... 7 \n3 [True, True, True, True, True, True, False, Fa... 6 \n4 [True, True, True, True, True, False, False, F... 5 \n... ... ... \n1000204 [False, False, False, False, False, False, Fal... 0 \n1000205 [False, False, False, False, False, False, Fal... 0 \n1000206 [False, False, False, False, False, False, Fal... 0 \n1000207 [False, False, False, False, False, False, Fal... 0 \n1000208 [False, False, False, False, False, False, Fal... 0 \n\n[1000209 rows x 8 columns]"
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "All looks as it should! Let's split this into training and validation sets."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "train_seq_df = seq_df[seq_df.is_valid == False]\nvalid_seq_df = seq_df[seq_df.is_valid == True]",
"execution_count": 45,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Training the model"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "As we saw previously, before we can feed this data into the model, we need to create lookup tables to encode our movies and users. However, this time, we need to include the padding token in our movie lookup."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "user_lookup = {v: i+1 for i, v in enumerate(ratings_df['user_id'].unique())}",
"execution_count": 46,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "def create_feature_lookup(df, feature):\n    lookup = {v: i+1 for i, v in enumerate(df[feature].unique())}\n    lookup['[PAD]'] = 0\n    return lookup",
"execution_count": 47,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "movie_lookup = create_feature_lookup(ratings_df, 'title')",
"execution_count": 48,
"outputs": []
},
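{
"metadata": {},
"cell_type": "markdown",
"source": "To see how the padding entry is used, here is a minimal sketch with a small, hypothetical `toy_movie_lookup` built the same way as `create_feature_lookup`, with `[PAD]` reserved as index 0. Because we will create the movie embedding layer with `padding_idx=0`, PyTorch initialises that row to zeros and never passes a gradient to it. As a result, padded positions always map to the same zero vector."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import torch\nfrom torch import nn\n\n# Hypothetical toy lookup: real titles start at 1 and '[PAD]' is reserved as 0\ntoy_movie_lookup = {'Titanic (1997)': 1, 'Cinderella (1950)': 2, '[PAD]': 0}\nsequence = ('[PAD]', '[PAD]', 'Cinderella (1950)', 'Titanic (1997)')\nmovie_ids = torch.tensor([toy_movie_lookup[title] for title in sequence])\nprint(movie_ids)  # tensor([0, 0, 2, 1])\n\n# With padding_idx=0, row 0 is initialised to zeros and never receives a gradient\nembedding = nn.Embedding(len(toy_movie_lookup), 4, padding_idx=0)\nembedding(movie_ids).sum().backward()\nprint(embedding.weight[0].detach())  # all zeros\nprint(embedding.weight.grad[0])  # all zeros",
"execution_count": null,
"outputs": []
},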
{
"metadata": {},
"cell_type": "markdown",
"source": "Now, we are dealing with sequences of ratings, rather than individual ones, so we will need to create a new dataset to wrap our processed DataFrame:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "class MovieSequenceDataset(Dataset):\n    def __init__(self, df, movie_lookup, user_lookup):\n        super().__init__()\n        self.df = df\n        self.movie_lookup = movie_lookup\n        self.user_lookup = user_lookup\n\n    def __len__(self):\n        return len(self.df)\n\n    def __getitem__(self, index):\n        data = self.df.iloc[index]\n        user_id = self.user_lookup[str(data.user_id)]\n        # Encode titles (including '[PAD]', which maps to 0) as integer IDs\n        movie_ids = torch.tensor([self.movie_lookup[title] for title in data.title])\n\n        # Replace padding in the rating history with 0\n        previous_ratings = torch.tensor(\n            [rating if rating != \"[PAD]\" else 0 for rating in data.previous_ratings]\n        )\n\n        # True marks padding positions that attention should ignore\n        attention_mask = torch.tensor(data.pad_mask)\n        target_rating = data.target_rating\n        encoded_features = {\n            \"user_id\": user_id,\n            \"movie_ids\": movie_ids,\n            \"ratings\": previous_ratings,\n        }\n\n        return (encoded_features, attention_mask), torch.tensor(\n            target_rating, dtype=torch.float32\n        )\n",
"execution_count": 49,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "train_dataset = MovieSequenceDataset(train_seq_df, movie_lookup, user_lookup)\nvalid_dataset = MovieSequenceDataset(valid_seq_df, movie_lookup, user_lookup)",
"execution_count": 50,
"outputs": []
},
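{
"metadata": {},
"cell_type": "markdown",
"source": "Each item returned by `MovieSequenceDataset` has the nested structure `((features, attention_mask), target)`. As a rough sketch of how such items are batched, the snippet below uses a hypothetical stand-in, `ToySequenceDataset`, that returns dummy values with the same structure. It relies on PyTorch's default `DataLoader` collation, which batches dictionaries key by key, stacks tensors, and turns the integer user IDs into a 1D tensor."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import torch\nfrom torch.utils.data import DataLoader, Dataset\n\n\nclass ToySequenceDataset(Dataset):\n    # Hypothetical stand-in returning dummy items with the same nested structure as MovieSequenceDataset\n    def __len__(self):\n        return 6\n\n    def __getitem__(self, index):\n        features = {\n            'user_id': index + 1,  # a plain int, like the encoded user ID\n            'movie_ids': torch.arange(10),\n            'ratings': torch.zeros(9, dtype=torch.long),\n        }\n        attention_mask = torch.zeros(10, dtype=torch.bool)\n        return (features, attention_mask), torch.tensor(4.0)\n\n\n# Default collation batches dicts key by key and tuples element by element\nloader = DataLoader(ToySequenceDataset(), batch_size=4)\n(features, mask), targets = next(iter(loader))\nprint(features['user_id'].shape, features['movie_ids'].shape, mask.shape, targets.shape)",
"execution_count": null,
"outputs": []
},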
{
"metadata": {},
"cell_type": "markdown",
"source": "Now, let's define our transformer model! As a starting point, given that the matrix factorization model can achieve good performance using only the user and movie IDs, let's include only this information for now."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "class BstTransformer(nn.Module):\n    def __init__(\n        self,\n        movies_num_unique,\n        users_num_unique,\n        sequence_length=10,\n        embedding_size=120,\n        num_transformer_layers=1,\n        ratings_range=(0.5, 5.5),\n    ):\n        super().__init__()\n        self.sequence_length = sequence_length\n        self.y_range = ratings_range\n        self.movies_embeddings = nn.Embedding(\n            movies_num_unique + 1, embedding_size, padding_idx=0\n        )\n        self.user_embeddings = nn.Embedding(users_num_unique + 1, embedding_size)\n        self.position_embeddings = nn.Embedding(sequence_length, embedding_size)\n\n        self.encoder = nn.TransformerEncoder(\n            encoder_layer=nn.TransformerEncoderLayer(\n                d_model=embedding_size,\n                nhead=12,\n                dropout=0.1,\n                batch_first=True,\n                activation=\"gelu\",\n            ),\n            num_layers=num_transformer_layers,\n        )\n\n        self.linear = nn.Sequential(\n            nn.Linear(\n                embedding_size + (embedding_size * sequence_length),\n                1024,\n            ),\n            nn.BatchNorm1d(1024),\n            nn.Mish(),\n            nn.Linear(1024, 512),\n            nn.BatchNorm1d(512),\n            nn.Mish(),\n            nn.Dropout(0.2),\n            nn.Linear(512, 256),\n            nn.BatchNorm1d(256),\n            nn.Mish(),\n            nn.Linear(256, 1),\n            nn.Sigmoid(),\n        )\n\n    def forward(self, inputs):\n        features, mask = inputs\n\n        encoded_user_id = self.user_embeddings(features[\"user_id\"])\n\n        user_features = encoded_user_id\n\n        encoded_movies = self.movies_embeddings(features[\"movie_ids\"])\n\n        positions = torch.arange(\n            0, self.sequence_length, 1, dtype=int, device=features[\"movie_ids\"].device\n        )\n        positions = self.position_embeddings(positions)\n\n        transformer_features = encoded_movies + positions\n\n        transformer_output = self.encoder(\n            transformer_features, src_key_padding_mask=mask\n        )\n        transformer_output = torch.flatten(transformer_output, start_dim=1)\n\n        combined_output = torch.cat((transformer_output, user_features), dim=1)\n\n        rating = self.linear(combined_output)\n        rating = rating.squeeze()\n        if self.y_range is None:\n            return rating\n        else:\n            return rating * (self.y_range[1] - self.y_range[0]) + self.y_range[0]\n",
"execution_count": 51,
"outputs": []
},
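{
"metadata": {},
"cell_type": "markdown",
"source": "To make the tensor shapes in `forward` concrete, here is a minimal sketch that repeats the same steps on random toy tensors instead of real embedding lookups. The batch size of 4 is an arbitrary assumption, while the other sizes match the defaults above. The learned positional embeddings broadcast across the batch. Flattening the encoder output gives `sequence_length * embedding_size` features per example, and concatenating the user embedding yields the `embedding_size + embedding_size * sequence_length` inputs expected by the first linear layer. Finally, a sigmoid output in (0, 1) is rescaled to the ratings range (0.5, 5.5)."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import torch\nfrom torch import nn\n\ntorch.manual_seed(0)\nbatch_size, sequence_length, embedding_size = 4, 10, 120  # batch size is arbitrary\n\n# Random stand-ins for the movie and user embedding lookups\nencoded_movies = torch.randn(batch_size, sequence_length, embedding_size)\nuser_features = torch.randn(batch_size, embedding_size)\n\n# Learned positional embeddings: (sequence_length, embedding_size) broadcasts over the batch\nposition_embeddings = nn.Embedding(sequence_length, embedding_size)\ntransformer_features = encoded_movies + position_embeddings(torch.arange(sequence_length))\n\nencoder = nn.TransformerEncoder(\n    nn.TransformerEncoderLayer(d_model=embedding_size, nhead=12, batch_first=True),\n    num_layers=1,\n)\npad_mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool)  # no padding in this toy batch\nencoded = encoder(transformer_features, src_key_padding_mask=pad_mask)  # (4, 10, 120)\n\nflattened = torch.flatten(encoded, start_dim=1)  # (4, 10 * 120)\ncombined = torch.cat((flattened, user_features), dim=1)\nprint(combined.shape)  # torch.Size([4, 1320]), i.e. 120 + 10 * 120\n\n# Rescale a sigmoid output from (0, 1) to the ratings range (0.5, 5.5)\ny_range = (0.5, 5.5)\nrating = torch.sigmoid(torch.randn(batch_size)) * (y_range[1] - y_range[0]) + y_range[0]",
"execution_count": null,
"outputs": []
},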
{
"metadata": {},
"cell_type": "markdown",
"source": "We can see that, by default, we add a learned positional embedding to each movie embedding and feed the resulting sequence into a single transformer layer. We then concatenate the flattened output with the user features (here, just the user ID embedding) and use this as the input to a fully connected network. The positional encoding is deliberately simple: it is learned to represent the order in which the movies were rated. In my experiments, a sine- and cosine-based approach provided no benefit, but feel free to try it out if you are interested!\n\nOnce again, let's define a training function for this model; except for the model initialization, this is identical to the one we used to train the matrix factorization model."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "def train_seq_model():\n    model = BstTransformer(\n        len(movie_lookup), len(user_lookup), sequence_length, embedding_size=120\n    )\n    loss_func = torch.nn.MSELoss()\n\n    optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)\n\n    create_sched_fn = partial(\n        torch.optim.lr_scheduler.OneCycleLR,\n        max_lr=0.01,\n        epochs=TrainerPlaceholderValues.NUM_EPOCHS,\n        steps_per_epoch=TrainerPlaceholderValues.NUM_UPDATE_STEPS_PER_EPOCH,\n    )\n\n    trainer = Trainer(\n        model=model,\n        loss_func=loss_func,\n        optimizer=optimizer,\n        callbacks=(\n            RecommenderMetricsCallback,\n            *DEFAULT_CALLBACKS,\n            SaveBestModelCallback(watch_metric=\"mae\"),\n            EarlyStoppingCallback(\n                early_stopping_patience=2,\n                early_stopping_threshold=0.001,\n                watch_metric=\"mae\",\n            ),\n        ),\n    )\n\n    trainer.train(\n        train_dataset=train_dataset,\n        eval_dataset=valid_dataset,\n        num_epochs=10,\n        per_device_batch_size=512,\n        create_scheduler_fn=create_sched_fn,\n    )\n",
"execution_count": 52,
"outputs": []
},
{
"metadata": {
"scrolled": true,
"trusted": false
},
"cell_type": "code",
"source": "notebook_launcher(train_seq_model, num_processes=2)",
"execution_count": 69,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Launching a training on 2 GPUs.\n\nStarting training run\n\nStarting epoch 1\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.81it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.9955023087630188\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.53it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7939572930335999\n\nmse: 0.9927792549133301\n\neval_loss_epoch: 0.9927793244520823\n\nrmse: 0.9963830864247597\n\nStarting epoch 2\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.87it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.8509480904722557\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.69it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7802140116691589\n\nmse: 0.9521594047546387\n\neval_loss_epoch: 0.9521594146887461\n\nrmse: 0.9757865569655275\n\nImprovement of 0.013743281364440918 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 3\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.90it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.8159997655974603\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.80it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7579830288887024\n\nmse: 0.915351152420044\n\neval_loss_epoch: 0.9153511722882589\n\nrmse: 0.9567398561887364\n\nImprovement of 0.022230982780456543 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 4\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.79it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.7925404456322274\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.75it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7406826615333557\n\nmse: 0.8825389742851257\n\neval_loss_epoch: 0.8825389941533407\n\nrmse: 0.9394354550926454\n\nImprovement of 0.01730036735534668 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 5\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.96it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.7654586238546793\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.43it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7357890009880066\n\nmse: 0.8756368160247803\n\neval_loss_epoch: 0.8756367762883505\n\nrmse: 0.935754677265778\n\nImprovement of 0.004893660545349121 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 6\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.81it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.7475585157912988\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.07it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7258664965629578\n\nmse: 0.8621974587440491\n\neval_loss_epoch: 0.8621974885463715\n\nrmse: 0.9285458840273049\n\nImprovement of 0.009922504425048828 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 7\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.92it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.7325706990199772\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.68it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7262701988220215\n\nmse: 0.8640387654304504\n\neval_loss_epoch: 0.8640388051668803\n\nrmse: 0.9295368553373505\nNo improvement above threshold observed, incrementing counter. \nEarly stopping counter: 1/2\n\nStarting epoch 8\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.80it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.7145110318393099\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.11it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7315012812614441\n\nmse: 0.8688936829566956\n\neval_loss_epoch: 0.8688936630884806\n\nrmse: 0.9321446684698119\nNo improvement above threshold observed, incrementing counter. \nEarly stopping counter: 2/2\nStopping training due to no improvement after 2 epochs\nFinishing training run\nLoading checkpoint with mae: 0.7258664965629578\n"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "We can see that this is a significant improvement over the matrix factorization approach!"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Adding additional data"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "So far, we have only considered the user ID and a sequence of movie IDs to predict the rating; it seems likely that including information about the previous ratings made by the user would improve performance. Thankfully, this is easy to do, and the data is already being returned by our dataset. Let's tweak our architecture to include this:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "class BstTransformer(nn.Module):\n    def __init__(\n        self,\n        movies_num_unique,\n        users_num_unique,\n        sequence_length=10,\n        embedding_size=120,\n        num_transformer_layers=1,\n        ratings_range=(0.5, 5.5),\n    ):\n        super().__init__()\n        self.sequence_length = sequence_length\n        self.y_range = ratings_range\n        self.movies_embeddings = nn.Embedding(\n            movies_num_unique + 1, embedding_size, padding_idx=0\n        )\n        self.user_embeddings = nn.Embedding(users_num_unique + 1, embedding_size)\n        self.ratings_embeddings = nn.Embedding(6, embedding_size, padding_idx=0)\n        self.position_embeddings = nn.Embedding(sequence_length, embedding_size)\n\n        self.encoder = nn.TransformerEncoder(\n            encoder_layer=nn.TransformerEncoderLayer(\n                d_model=embedding_size,\n                nhead=12,\n                dropout=0.1,\n                batch_first=True,\n                activation=\"gelu\",\n            ),\n            num_layers=num_transformer_layers,\n        )\n\n        self.linear = nn.Sequential(\n            nn.Linear(\n                embedding_size + (embedding_size * sequence_length),\n                1024,\n            ),\n            nn.BatchNorm1d(1024),\n            nn.Mish(),\n            nn.Linear(1024, 512),\n            nn.BatchNorm1d(512),\n            nn.Mish(),\n            nn.Dropout(0.2),\n            nn.Linear(512, 256),\n            nn.BatchNorm1d(256),\n            nn.Mish(),\n            nn.Linear(256, 1),\n            nn.Sigmoid(),\n        )\n\n    def forward(self, inputs):\n        features, mask = inputs\n\n        encoded_user_id = self.user_embeddings(features[\"user_id\"])\n\n        user_features = encoded_user_id\n\n        movie_history = features[\"movie_ids\"][:, :-1]\n        target_movie = features[\"movie_ids\"][:, -1]\n\n        ratings = self.ratings_embeddings(features[\"ratings\"])\n\n        encoded_movies = self.movies_embeddings(movie_history)\n        encoded_target_movie = self.movies_embeddings(target_movie)\n\n        positions = torch.arange(\n            0,\n            self.sequence_length - 1,\n            1,\n            dtype=int,\n            device=features[\"movie_ids\"].device,\n        )\n        positions = self.position_embeddings(positions)\n\n        encoded_sequence_movies_with_position_and_rating = (\n            encoded_movies + ratings + positions\n        )\n        encoded_target_movie = encoded_target_movie.unsqueeze(1)\n\n        transformer_features = torch.cat(\n            (encoded_sequence_movies_with_position_and_rating, encoded_target_movie),\n            dim=1,\n        )\n        transformer_output = self.encoder(\n            transformer_features, src_key_padding_mask=mask\n        )\n        transformer_output = torch.flatten(transformer_output, start_dim=1)\n\n        combined_output = torch.cat((transformer_output, user_features), dim=1)\n\n        rating = self.linear(combined_output)\n        rating = rating.squeeze()\n        if self.y_range is None:\n            return rating\n        else:\n            return rating * (self.y_range[1] - self.y_range[0]) + self.y_range[0]\n",
"execution_count": 53,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "We can see that, to use the ratings data, we have added an additional embedding layer. For each previously rated movie, we then add together the movie embedding, the position embedding and the rating embedding before feeding this sequence into the transformer. Alternatively, the rating data could be concatenated to, or multiplied with, the movie embedding, but adding them together worked best of the approaches that I tried.\n\nBecause our training function creates a `BstTransformer` when it runs, and re-executing the cell above has rebound that name to the new class in the notebook's namespace, we don't need to update the training function; the new architecture will be used when we launch training:"
},
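{
"metadata": {},
"cell_type": "markdown",
"source": "To make the combination step more concrete, here is a minimal, self-contained sketch using made-up sizes (a batch of 2, a history of 4 movies and an embedding size of 8; these are illustrative assumptions, not values from our dataset). It shows that summing the movie, rating and position embeddings keeps the last dimension equal to the embedding size, which is what the transformer expects as `d_model`, whereas concatenating them would triple it. It also shows that the padding row (index 0) of the ratings embedding is initialised to zeros, which fits with padded ratings being mapped to 0 in our dataset."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import torch\nfrom torch import nn\n\n# Illustrative sizes (assumptions, not the notebook's real values);\n# names are prefixed with toy_ so they don't overwrite notebook variables\ntoy_batch_size, toy_history_length, toy_embedding_size = 2, 4, 8\ntoy_num_movies = 20\n\ntoy_movie_embeddings = nn.Embedding(toy_num_movies + 1, toy_embedding_size, padding_idx=0)\n# Index 0 is reserved for padding; ratings 1-5 use indices 1-5\ntoy_rating_embeddings = nn.Embedding(6, toy_embedding_size, padding_idx=0)\ntoy_position_embeddings = nn.Embedding(toy_history_length, toy_embedding_size)\n\ntoy_movie_history = torch.randint(1, toy_num_movies + 1, (toy_batch_size, toy_history_length))\ntoy_ratings = torch.randint(1, 6, (toy_batch_size, toy_history_length))\n\n# Shape (history_length, embedding_size); broadcasts across the batch dimension\ntoy_positions = toy_position_embeddings(torch.arange(toy_history_length))\n\n# Element-wise sum, as in the model above\ntoy_summed = (\n    toy_movie_embeddings(toy_movie_history)\n    + toy_rating_embeddings(toy_ratings)\n    + toy_positions\n)\nprint(toy_summed.shape)  # torch.Size([2, 4, 8])\n\n# The concatenation alternative triples the feature size at each position\ntoy_concatenated = torch.cat(\n    (\n        toy_movie_embeddings(toy_movie_history),\n        toy_rating_embeddings(toy_ratings),\n        toy_positions.expand(toy_batch_size, -1, -1),\n    ),\n    dim=-1,\n)\nprint(toy_concatenated.shape)  # torch.Size([2, 4, 24])\n\n# The padding row (index 0) is initialised to zeros\nprint(toy_rating_embeddings(torch.tensor([0])).abs().sum().item())  # 0.0\n",
"execution_count": null,
"outputs": []
},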
{
"metadata": {
"scrolled": true,
"trusted": false
},
"cell_type": "code",
"source": "notebook_launcher(train_seq_model, num_processes=2)",
"execution_count": 71,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Launching a training on 2 GPUs.\n\nStarting training run\n\nStarting epoch 1\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.60it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.9109353272111973\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.39it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.8022098541259766\n\nmse: 0.9802775979042053\n\neval_loss_epoch: 0.9802776078383127\n\nrmse: 0.9900896918482716\n\nStarting epoch 2\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.80it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.8358323996393614\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.25it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7573742866516113\n\nmse: 0.9179417490959167\n\neval_loss_epoch: 0.9179418087005615\n\nrmse: 0.9580927664354412\n\nImprovement of 0.044835567474365234 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 3\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.76it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.8017225273482954\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.69it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7416232228279114\n\nmse: 0.8967887759208679\n\neval_loss_epoch: 0.8967887858549753\n\nrmse: 0.9469893219677126\n\nImprovement of 0.01575106382369995 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 4\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.73it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.7820610726898657\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.47it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7375993728637695\n\nmse: 0.8765184283256531\n\neval_loss_epoch: 0.8765184382597605\n\nrmse: 0.9362256289621925\n\nImprovement of 0.004023849964141846 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 5\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.60it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.7703093529729715\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.07it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7289111018180847\n\nmse: 0.8735694885253906\n\neval_loss_epoch: 0.8735695282618204\n\nrmse: 0.9346493933691877\n\nImprovement of 0.008688271045684814 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 6\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.65it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.7511685453777333\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 9.98it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7231311798095703\n\nmse: 0.8583566546440125\n\neval_loss_epoch: 0.8583566149075826\n\nrmse: 0.926475393436875\n\nImprovement of 0.005779922008514404 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 7\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.72it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.7281422661089872\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.58it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7262148261070251\n\nmse: 0.8491864204406738\n\neval_loss_epoch: 0.849186360836029\n\nrmse: 0.9215131146330332\nNo improvement above threshold observed, incrementing counter. \nEarly stopping counter: 1/2\n\nStarting epoch 8\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.64it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.709694542980587\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.26it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7182666659355164\n\nmse: 0.8506280779838562\n\neval_loss_epoch: 0.8506280283133189\n\nrmse: 0.9222950059410797\n\nImprovement of 0.004864513874053955 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 9\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.70it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.6928928330556003\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.42it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7204784750938416\n\nmse: 0.8569676876068115\n\neval_loss_epoch: 0.8569677571455637\n\nrmse: 0.9257254925769364\nNo improvement above threshold observed, incrementing counter. \nEarly stopping counter: 1/2\n\nStarting epoch 10\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:44<00:00, 21.75it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.6806765550442508\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.45it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7206871509552002\n\nmse: 0.8620250225067139\n\neval_loss_epoch: 0.8620249728361765\n\nrmse: 0.9284530265483084\nNo improvement above threshold observed, incrementing counter. \nEarly stopping counter: 2/2\nStopping training due to no improvement after 2 epochs\nFinishing training run\nLoading checkpoint with mae: 0.7182666659355164\n"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
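"source": "We can see that incorporating the ratings data has improved our results slightly: at the best checkpoint, the MAE drops from about 0.726 to 0.718, and the RMSE from about 0.929 to 0.922."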
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Adding user features"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "In addition to the ratings data, we also have more information about the users that we could add to the model. As a reminder, let's take a look at the users table:"
},
{
"metadata": {
"scrolled": false,
"trusted": false
},
"cell_type": "code",
"source": "users",
"execution_count": 54,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>user_id</th>\n <th>sex</th>\n <th>age_group</th>\n <th>occupation</th>\n <th>zip_code</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>1</td>\n <td>F</td>\n <td>1</td>\n <td>10</td>\n <td>48067</td>\n </tr>\n <tr>\n <th>1</th>\n <td>2</td>\n <td>M</td>\n <td>56</td>\n <td>16</td>\n <td>70072</td>\n </tr>\n <tr>\n <th>2</th>\n <td>3</td>\n <td>M</td>\n <td>25</td>\n <td>15</td>\n <td>55117</td>\n </tr>\n <tr>\n <th>3</th>\n <td>4</td>\n <td>M</td>\n <td>45</td>\n <td>7</td>\n <td>02460</td>\n </tr>\n <tr>\n <th>4</th>\n <td>5</td>\n <td>M</td>\n <td>25</td>\n <td>20</td>\n <td>55455</td>\n </tr>\n <tr>\n <th>...</th>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n </tr>\n <tr>\n <th>6035</th>\n <td>6036</td>\n <td>F</td>\n <td>25</td>\n <td>15</td>\n <td>32603</td>\n </tr>\n <tr>\n <th>6036</th>\n <td>6037</td>\n <td>F</td>\n <td>45</td>\n <td>1</td>\n <td>76006</td>\n </tr>\n <tr>\n <th>6037</th>\n <td>6038</td>\n <td>F</td>\n <td>56</td>\n <td>1</td>\n <td>14706</td>\n </tr>\n <tr>\n <th>6038</th>\n <td>6039</td>\n <td>F</td>\n <td>45</td>\n <td>0</td>\n <td>01060</td>\n </tr>\n <tr>\n <th>6039</th>\n <td>6040</td>\n <td>M</td>\n <td>25</td>\n <td>6</td>\n <td>11106</td>\n </tr>\n </tbody>\n</table>\n<p>6040 rows × 5 columns</p>\n</div>",
"text/plain": " user_id sex age_group occupation zip_code\n0 1 F 1 10 48067\n1 2 M 56 16 70072\n2 3 M 25 15 55117\n3 4 M 45 7 02460\n4 5 M 25 20 55455\n... ... .. ... ... ...\n6035 6036 F 25 15 32603\n6036 6037 F 45 1 76006\n6037 6038 F 56 1 14706\n6038 6039 F 45 0 01060\n6039 6040 M 25 6 11106\n\n[6040 rows x 5 columns]"
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Let's try adding the categorical variables representing the users' sex, age group and occupation to the model, and see whether they improve performance. Occupation already appears to be encoded as consecutive integers, but sex is stored as strings ('F'/'M') and age_group uses non-consecutive codes (such as 1, 25, 45 and 56 in the table above), so we need to map both of these columns to consecutive integers before they can be used as embedding indices. We can use the `LabelEncoder` class from scikit-learn to do this for us, and append the encoded columns to the DataFrame:"
},
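{
"metadata": {},
"cell_type": "markdown",
"source": "As a quick illustration of what `LabelEncoder` does, here it is applied to a few of the raw age-group codes visible in the table above (a toy input, not the full column). It sorts the unique values and maps each one to its position in that sorted order, producing consecutive integers starting at 0 that can be used directly as embedding indices:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "from sklearn.preprocessing import LabelEncoder\n\n# A few raw age-group codes taken from the users table shown above\ntoy_age_groups = [1, 56, 25, 45, 25]\n\ntoy_encoder = LabelEncoder()\ntoy_encoded = toy_encoder.fit_transform(toy_age_groups)\n\nprint(toy_encoder.classes_.tolist())  # [1, 25, 45, 56]\nprint(toy_encoded.tolist())  # [0, 3, 1, 2, 1]\n\n# String columns such as sex are handled the same way\ntoy_sex_encoded = LabelEncoder().fit_transform(['F', 'M', 'M', 'F'])\nprint(toy_sex_encoded.tolist())  # [0, 1, 1, 0]\n",
"execution_count": null,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Note that `fit_transform` refits the encoder every time it is called, so when a single `le` instance is reused for both columns, only the mapping for the most recently encoded column is kept. This doesn't affect the encoded values themselves, but separate encoders would be needed if we later wanted to call `inverse_transform` on both columns."
},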
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "from sklearn.preprocessing import LabelEncoder",
"execution_count": 55,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "le = LabelEncoder()",
"execution_count": 56,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "users['sex_encoded'] = le.fit_transform(users.sex)",
"execution_count": 57,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "users['age_group_encoded'] = le.fit_transform(users.age_group)",
"execution_count": 58,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "users[\"user_id\"] = users[\"user_id\"].astype(str)",
"execution_count": 59,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Now that all the features we plan to use are encoded, let's join the user features to our sequences DataFrame and update our training and validation sets."
},
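{
"metadata": {},
"cell_type": "markdown",
"source": "For readers less familiar with `pd.merge`, here is a toy example using hypothetical miniature versions of `seq_df` and `users`. When no `on` argument is given, pandas joins on all of the columns the two DataFrames share (here only `user_id`), and performs an inner join by default. The key columns also need compatible types, which is presumably why `user_id` was cast to `str` above."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import pandas as pd\n\n# Hypothetical miniature versions of seq_df and users (toy_ names avoid overwriting the real ones)\ntoy_seq_df = pd.DataFrame(\n    {\n        'user_id': ['1', '1', '2'],\n        'target_rating': [4, 5, 3],\n        'is_valid': [False, True, False],\n    }\n)\ntoy_users = pd.DataFrame(\n    {\n        'user_id': ['1', '2'],\n        'sex_encoded': [0, 1],\n        'age_group_encoded': [0, 6],\n        'occupation': [10, 16],\n    }\n)\n\n# Inner join on the only shared column, 'user_id'\ntoy_merged = pd.merge(toy_seq_df, toy_users)\nprint(toy_merged)\n\n# Boolean indexing on the is_valid flag\ntoy_train_df = toy_merged[~toy_merged.is_valid]\ntoy_valid_df = toy_merged[toy_merged.is_valid]\nprint(len(toy_train_df), len(toy_valid_df))  # 2 1\n",
"execution_count": null,
"outputs": []
},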
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "seq_with_user_features = pd.merge(seq_df, users)",
"execution_count": 60,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "train_df = seq_with_user_features[seq_with_user_features.is_valid == False]\nvalid_df = seq_with_user_features[seq_with_user_features.is_valid == True]",
"execution_count": 61,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Let's update our dataset to include these features."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "class MovieSequenceDataset(Dataset):\n    def __init__(self, df, movie_lookup, user_lookup):\n        super().__init__()\n        self.df = df\n        self.movie_lookup = movie_lookup\n        self.user_lookup = user_lookup\n\n    def __len__(self):\n        return len(self.df)\n\n    def __getitem__(self, index):\n        data = self.df.iloc[index]\n        user_id = self.user_lookup[str(data.user_id)]\n        movie_ids = torch.tensor([self.movie_lookup[title] for title in data.title])\n\n        previous_ratings = torch.tensor(\n            [rating if rating != \"[PAD]\" else 0 for rating in data.previous_ratings]\n        )\n\n        attention_mask = torch.tensor(data.pad_mask)\n        target_rating = data.target_rating\n        encoded_features = {\n            \"user_id\": user_id,\n            \"movie_ids\": movie_ids,\n            \"ratings\": previous_ratings,\n            \"age_group\": data[\"age_group_encoded\"],\n            \"sex\": data[\"sex_encoded\"],\n            \"occupation\": data[\"occupation\"],\n        }\n\n        return (encoded_features, attention_mask), torch.tensor(\n            target_rating, dtype=torch.float32\n        )\n",
"execution_count": 62,
"outputs": []
},
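{
"metadata": {},
"cell_type": "markdown",
"source": "Our dataset returns a nested structure: a tuple of a feature dictionary and an attention mask, together with the target rating. The sketch below uses a hypothetical toy dataset that mimics this structure to illustrate how PyTorch's default `DataLoader` collation batches it: each dictionary value is collated separately, so Python scalars become 1D tensors and tensors are stacked along a new batch dimension. This is what allows the model's `forward` method to look up batched features such as `sex` and `age_group` by name."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import torch\nfrom torch.utils.data import DataLoader, Dataset\n\n\nclass ToyFeatureDataset(Dataset):\n    # Hypothetical stand-in returning the same nested structure as MovieSequenceDataset\n    def __len__(self):\n        return 4\n\n    def __getitem__(self, index):\n        features = {\n            'user_id': index + 1,  # Python scalars are collated into a 1D tensor\n            'movie_ids': torch.tensor([1, 2, 3]),  # tensors are stacked along a new batch dimension\n            'sex': index % 2,\n        }\n        attention_mask = torch.tensor([False, False, False])\n        return (features, attention_mask), torch.tensor(4.0)\n\n\ntoy_loader = DataLoader(ToyFeatureDataset(), batch_size=4)\n(toy_features, toy_mask), toy_targets = next(iter(toy_loader))\n\nprint(toy_features['user_id'])  # tensor([1, 2, 3, 4])\nprint(toy_features['movie_ids'].shape)  # torch.Size([4, 3])\nprint(toy_mask.shape)  # torch.Size([4, 3])\nprint(toy_targets.shape)  # torch.Size([4])\n",
"execution_count": null,
"outputs": []
},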
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "train_dataset = MovieSequenceDataset(train_df, movie_lookup, user_lookup)\nvalid_dataset = MovieSequenceDataset(valid_df, movie_lookup, user_lookup)",
"execution_count": 63,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "We can now modify our architecture to include an embedding layer for each of these features. The resulting embeddings are concatenated with the user ID embedding and the flattened output of the transformer, and the combined vector is then passed into the feed-forward network."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "class BstTransformer(nn.Module):\n    def __init__(\n        self,\n        movies_num_unique,\n        users_num_unique,\n        sequence_length=10,\n        embedding_size=120,\n        num_transformer_layers=1,\n        ratings_range=(0.5, 5.5),\n    ):\n        super().__init__()\n        self.sequence_length = sequence_length\n        self.y_range = ratings_range\n        self.movies_embeddings = nn.Embedding(\n            movies_num_unique + 1, embedding_size, padding_idx=0\n        )\n        self.user_embeddings = nn.Embedding(users_num_unique + 1, embedding_size)\n        self.ratings_embeddings = nn.Embedding(6, embedding_size, padding_idx=0)\n        self.position_embeddings = nn.Embedding(sequence_length, embedding_size)\n\n        self.sex_embeddings = nn.Embedding(\n            3,\n            2,\n        )\n        self.occupation_embeddings = nn.Embedding(\n            22,\n            11,\n        )\n        self.age_group_embeddings = nn.Embedding(\n            8,\n            4,\n        )\n\n        self.encoder = nn.TransformerEncoder(\n            encoder_layer=nn.TransformerEncoderLayer(\n                d_model=embedding_size,\n                nhead=12,\n                dropout=0.1,\n                batch_first=True,\n                activation=\"gelu\",\n            ),\n            num_layers=num_transformer_layers,\n        )\n\n        self.linear = nn.Sequential(\n            nn.Linear(\n                # user ID embedding + flattened transformer output + age group (4) + occupation (11) + sex (2)\n                embedding_size + (embedding_size * sequence_length) + 4 + 11 + 2,\n                1024,\n            ),\n            nn.BatchNorm1d(1024),\n            nn.Mish(),\n            nn.Linear(1024, 512),\n            nn.BatchNorm1d(512),\n            nn.Mish(),\n            nn.Dropout(0.2),\n            nn.Linear(512, 256),\n            nn.BatchNorm1d(256),\n            nn.Mish(),\n            nn.Linear(256, 1),\n            nn.Sigmoid(),\n        )\n\n    def forward(self, inputs):\n        features, mask = inputs\n\n        user_id = self.user_embeddings(features[\"user_id\"])\n\n        age_group = self.age_group_embeddings(features[\"age_group\"])\n        sex = self.sex_embeddings(features[\"sex\"])\n        occupation = self.occupation_embeddings(features[\"occupation\"])\n\n        user_features = torch.cat(\n            (user_id, sex, age_group, occupation), 1\n        )\n\n        movie_history = features[\"movie_ids\"][:, :-1]\n        target_movie = features[\"movie_ids\"][:, -1]\n\n        ratings = self.ratings_embeddings(features[\"ratings\"])\n\n        encoded_movies = self.movies_embeddings(movie_history)\n        encoded_target_movie = self.movies_embeddings(target_movie)\n\n        positions = torch.arange(\n            0,\n            self.sequence_length - 1,\n            1,\n            dtype=int,\n            device=features[\"movie_ids\"].device,\n        )\n        positions = self.position_embeddings(positions)\n\n        encoded_sequence_movies_with_position_and_rating = (\n            encoded_movies + ratings + positions\n        )\n        encoded_target_movie = encoded_target_movie.unsqueeze(1)\n\n        transformer_features = torch.cat(\n            (encoded_sequence_movies_with_position_and_rating, encoded_target_movie),\n            dim=1,\n        )\n        transformer_output = self.encoder(\n            transformer_features, src_key_padding_mask=mask\n        )\n        transformer_output = torch.flatten(transformer_output, start_dim=1)\n\n        combined_output = torch.cat((transformer_output, user_features), dim=1)\n\n        rating = self.linear(combined_output)\n        rating = rating.squeeze()\n        if self.y_range is None:\n            return rating\n        else:\n            return rating * (self.y_range[1] - self.y_range[0]) + self.y_range[0]\n",
"execution_count": 64,
"outputs": []
},
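{
"metadata": {},
"cell_type": "markdown",
"source": "To see where the input size of the first linear layer comes from, here is a minimal sketch of the concatenation using random tensors in place of real embeddings. The batch size of 2 is an arbitrary assumption; the other sizes match the class defaults (an embedding size of 120 and a sequence length of 10) and the sizes chosen above for the sex, age group and occupation embeddings:"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import torch\n\n# The batch size is arbitrary; the other sizes match the class defaults and embedding sizes\ntoy_batch_size = 2\ntoy_embedding_size, toy_sequence_length = 120, 10\n\n# Flattened transformer output: one embedding per sequence position\ntoy_transformer_output = torch.randn(\n    toy_batch_size, toy_sequence_length, toy_embedding_size\n).flatten(start_dim=1)\n\n# User-level features, concatenated in the same order as in forward()\ntoy_user_features = torch.cat(\n    (\n        torch.randn(toy_batch_size, toy_embedding_size),  # user ID\n        torch.randn(toy_batch_size, 2),  # sex\n        torch.randn(toy_batch_size, 4),  # age group\n        torch.randn(toy_batch_size, 11),  # occupation\n    ),\n    dim=1,\n)\n\ntoy_combined = torch.cat((toy_transformer_output, toy_user_features), dim=1)\nprint(toy_combined.shape)  # torch.Size([2, 1337])\n\n# Matches the in_features of the first nn.Linear layer in the class above\nassert toy_combined.shape[1] == (\n    toy_embedding_size + toy_embedding_size * toy_sequence_length + 4 + 11 + 2\n)\n",
"execution_count": null,
"outputs": []
},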
{
"metadata": {
"scrolled": true,
"trusted": false
},
"cell_type": "code",
"source": "notebook_launcher(train_seq_model, num_processes=2)",
"execution_count": 68,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Launching a training on 2 GPUs.\n\nStarting training run\n\nStarting epoch 1\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:46<00:00, 20.93it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.9115137239317692\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 9.88it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7698847651481628\n\neval_loss_epoch: 0.9531584481398264\n\nrmse: 0.9762983956839264\n\nmse: 0.9531585574150085\n\nStarting epoch 2\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:46<00:00, 21.11it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.8351250770285986\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 10.07it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7515485882759094\n\neval_loss_epoch: 0.9225256244341532\n\nrmse: 0.9604819543842161\n\nmse: 0.9225255846977234\n\nImprovement of 0.018336176872253418 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 3\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:45<00:00, 21.14it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.804713054002866\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 9.38it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.743607223033905\n\neval_loss_epoch: 0.8977507948875427\n\nrmse: 0.9474971527620479\n\nmse: 0.8977508544921875\n\nImprovement of 0.007941365242004395 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 4\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:45<00:00, 21.14it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.7829881388421653\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 9.94it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7408876419067383\n\neval_loss_epoch: 0.8879891335964203\n\nrmse: 0.9423317060300227\n\nmse: 0.8879890441894531\n\nImprovement of 0.002719581127166748 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 5\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:45<00:00, 21.21it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.7723518149000732\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 9.62it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7306551337242126\n\neval_loss_epoch: 0.8741245567798615\n\nrmse: 0.9349462695671549\n\nmse: 0.8741245269775391\n\nImprovement of 0.010232508182525635 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 6\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:45<00:00, 21.19it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.7589085090418186\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 9.82it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7242081761360168\n\neval_loss_epoch: 0.87059153119723\n\nrmse: 0.9330549081691162\n\nmse: 0.8705914616584778\n\nImprovement of 0.006446957588195801 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 7\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:45<00:00, 21.14it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.7346186338934422\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 8.77it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7160568833351135\n\neval_loss_epoch: 0.8519508838653564\n\nrmse: 0.9230118546721686\n\nmse: 0.8519508838653564\n\nImprovement of 0.00815129280090332 observed, resetting counter. \nEarly stopping counter: 0/2\n\nStarting epoch 8\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:45<00:00, 21.14it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.7128203637690794\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 9.90it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.7230656743049622\n\neval_loss_epoch: 0.8604253133138021\n\nrmse: 0.9275911240657637\n\nmse: 0.8604252934455872\nNo improvement above threshold observed, incrementing counter. \nEarly stopping counter: 1/2\n\nStarting epoch 9\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 971/971 [00:46<00:00, 21.07it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\ntrain_loss_epoch: 0.6947063981074875\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "100%|██████████| 6/6 [00:00<00:00, 9.97it/s]\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "\nmae: 0.723215639591217\n\neval_loss_epoch: 0.8628555238246918\n\nrmse: 0.928900152881013\n\nmse: 0.8628554940223694\nNo improvement above threshold observed, incrementing counter. \nEarly stopping counter: 2/2\nStopping training due to no improvement after 2 epochs\nFinishing training run\nLoading checkpoint with mae: 0.7160568833351135\n"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Here, we can see a slight decrease in the MAE (from about 0.718 to 0.716), but a small increase in the MSE (from about 0.851 to 0.852) and the RMSE (from about 0.922 to 0.923), so it looks like these features made a negligible difference to the overall performance."
},
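{
"metadata": {},
"cell_type": "markdown",
"source": "To make the comparison between the three transformer runs explicit, the snippet below computes the change in each validation metric between consecutive runs. It uses the metrics of the checkpoint reloaded at the end of each run (the epoch with the lowest MAE), copied by hand from the training logs above; the run labels are just descriptive names."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# Validation metrics of the checkpoint reloaded at the end of each run, copied from the logs above\ntoy_runs = {\n    'user ID + movie history': {\n        'mae': 0.7258664965629578, 'mse': 0.8621974587440491, 'rmse': 0.9285458840273049\n    },\n    '+ previous ratings': {\n        'mae': 0.7182666659355164, 'mse': 0.8506280779838562, 'rmse': 0.9222950059410797\n    },\n    '+ user features': {\n        'mae': 0.7160568833351135, 'mse': 0.8519508838653564, 'rmse': 0.9230118546721686\n    },\n}\n\n# Change in each metric between consecutive runs (negative means lower error)\ntoy_names = list(toy_runs)\nfor toy_prev, toy_curr in zip(toy_names, toy_names[1:]):\n    toy_deltas = {\n        metric: round(toy_runs[toy_curr][metric] - toy_runs[toy_prev][metric], 4)\n        for metric in ('mae', 'mse', 'rmse')\n    }\n    print(f'{toy_prev} -> {toy_curr}: {toy_deltas}')\n",
"execution_count": null,
"outputs": []
},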
{
"metadata": {},
"cell_type": "markdown",
"source": "In writing this article, my main objective has been to illustrate how these approaches can be used, so I've picked the hyperparameters somewhat arbitrarily; with some hyperparameter tuning and different combinations of features, these metrics can probably be improved upon!\n\nHopefully this has provided a good introduction to using both matrix factorization and transformer-based approaches in PyTorch, and to how pytorch-accelerated can speed up the process of experimenting with different models!"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "conda-env-py37_pytorch-py",
"display_name": "Python [conda env:py37_pytorch]",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.7.9",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"gist": {
"id": "",
"data": {
"description": "Comparing matrix factorixation with transformers using pytorch-accelerated blog post.ipynb",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}