  • Save radekosmulski/72294050295c6e84c5a470dea9e637d1 to your computer and use it in GitHub Desktop.
{
"cells": [
{
"cell_type": "markdown",
"id": "d34ea336",
"metadata": {},
"source": [
"# Train a Merlin Two Tower model and perform offline batch inference"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "464aee37",
"metadata": {},
"outputs": [],
"source": [
"import nvtabular as nvt\n",
"import cudf\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"id": "f827f562",
"metadata": {},
"source": [
"# Create data"
]
},
{
"cell_type": "markdown",
"id": "6d798569",
"metadata": {},
"source": [
"Two categories of users (gold and silver), each split into two age groups:\n",
"* Gold users younger than 50 buy blue Porsches\n",
"* Gold users older than 50 buy red Porsches\n",
"* Silver users younger than 70 buy blue Volkswagens\n",
"* Silver users older than 70 buy red Fiats"
]
},
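{
"cell_type": "markdown",
"id": "rules-sketch-md",
"metadata": {},
"source": [
"The rules above can also be written as a small lookup function, which is handy for checking predictions by hand. This is a minimal sketch; `expected_car` is a hypothetical helper, not part of Merlin. It mirrors the strict `<` and `>` comparisons used in the filtering cell below, so gold users aged exactly 50 and silver users aged exactly 70 match no rule and do not appear in the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "rules-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional, Tuple\n",
"\n",
"\n",
"def expected_car(user_type: str, user_age: int) -> Optional[Tuple[str, str]]:\n",
"    # Hypothetical helper: returns (car_make, car_color), or None if no rule applies.\n",
"    # Strict inequalities mirror the filters used to build `df`.\n",
"    if user_type == 'gold':\n",
"        if user_age < 50:\n",
"            return ('Porsche', 'blue')\n",
"        if user_age > 50:\n",
"            return ('Porsche', 'red')\n",
"    elif user_type == 'silver':\n",
"        if user_age < 70:\n",
"            return ('Volkswagen', 'blue')\n",
"        if user_age > 70:\n",
"            return ('Fiat', 'red')\n",
"    return None\n",
"\n",
"\n",
"print(expected_car('gold', 42))    # ('Porsche', 'blue')\n",
"print(expected_car('silver', 70))  # None: age 70 matches neither silver rule"
]
},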
{
"cell_type": "code",
"execution_count": 2,
"id": "1d886bf3",
"metadata": {},
"outputs": [],
"source": [
"car_makes = ['Porsche', 'Volkswagen', 'Fiat']\n",
"\n",
"# 100,000 random rows: user type, car color, user age (30-89) and car make\n",
"df = cudf.DataFrame(data={\n",
" 'user_type': np.array(['gold', 'silver'])[np.random.randint(0, 2, 100_000)],\n",
" 'car_color': np.array(['blue', 'red'])[np.random.randint(0, 2, 100_000)],\n",
" 'user_age': np.random.randint(30, 90, 100_000),\n",
" 'car_make': np.array(car_makes)[np.random.randint(0, 3, 100_000)],\n",
"})\n",
"\n",
"# Keep only the rows that follow one of the four purchase rules\n",
"df = cudf.concat([\n",
" df[(df.user_type == 'gold') & (df.user_age < 50) & (df.car_color == 'blue') & (df.car_make == 'Porsche')],\n",
" df[(df.user_type == 'gold') & (df.user_age > 50) & (df.car_color == 'red') & (df.car_make == 'Porsche')],\n",
" df[(df.user_type == 'silver') & (df.user_age < 70) & (df.car_color == 'blue') & (df.car_make == 'Volkswagen')],\n",
" df[(df.user_type == 'silver') & (df.user_age > 70) & (df.car_color == 'red') & (df.car_make == 'Fiat')],\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4fd8e8b4",
"metadata": {},
"outputs": [],
"source": [
"# Map each unique make + color combination to an integer car_id\n",
"car_make_color = df.car_make + df.car_color\n",
"unique_car_make_color_combos = car_make_color.unique().to_pandas().to_list()\n",
"df['car_id'] = car_make_color.to_pandas().apply(lambda c: unique_car_make_color_combos.index(c))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f9de5db1",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user_type</th>\n",
" <th>car_color</th>\n",
" <th>user_age</th>\n",
" <th>car_make</th>\n",
" <th>car_id</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>80255</th>\n",
" <td>silver</td>\n",
" <td>blue</td>\n",
" <td>32</td>\n",
" <td>Volkswagen</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>61861</th>\n",
" <td>silver</td>\n",
" <td>red</td>\n",
" <td>89</td>\n",
" <td>Fiat</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>87295</th>\n",
" <td>silver</td>\n",
" <td>red</td>\n",
" <td>87</td>\n",
" <td>Fiat</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29373</th>\n",
" <td>gold</td>\n",
" <td>red</td>\n",
" <td>86</td>\n",
" <td>Porsche</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>51651</th>\n",
" <td>silver</td>\n",
" <td>blue</td>\n",
" <td>42</td>\n",
" <td>Volkswagen</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12309</th>\n",
" <td>gold</td>\n",
" <td>red</td>\n",
" <td>72</td>\n",
" <td>Porsche</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50211</th>\n",
" <td>gold</td>\n",
" <td>blue</td>\n",
" <td>40</td>\n",
" <td>Porsche</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90590</th>\n",
" <td>gold</td>\n",
" <td>red</td>\n",
" <td>65</td>\n",
" <td>Porsche</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>54812</th>\n",
" <td>gold</td>\n",
" <td>blue</td>\n",
" <td>44</td>\n",
" <td>Porsche</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>132</th>\n",
" <td>gold</td>\n",
" <td>red</td>\n",
" <td>83</td>\n",
" <td>Porsche</td>\n",
" <td>2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" user_type car_color user_age car_make car_id\n",
"80255 silver blue 32 Volkswagen 3\n",
"61861 silver red 89 Fiat 0\n",
"87295 silver red 87 Fiat 0\n",
"29373 gold red 86 Porsche 2\n",
"51651 silver blue 42 Volkswagen 3\n",
"12309 gold red 72 Porsche 2\n",
"50211 gold blue 40 Porsche 1\n",
"90590 gold red 65 Porsche 2\n",
"54812 gold blue 44 Porsche 1\n",
"132 gold red 83 Porsche 2"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.sample(n=10)"
]
},
{
"cell_type": "markdown",
"id": "579b2e1e",
"metadata": {},
"source": [
"The model expects an item id (`car_id`). This matches most real-life scenarios, where every item in the catalogue has an identifier.\n",
"\n",
"The id is not necessary for training the Two Tower model itself; we need it for performing offline predictions.\n",
"\n",
"In the case of a cold start problem, where we have no user/`car_id` interactions for an item, we can substitute the embedding for that item's `car_id` with the mean of the `car_id` embeddings of all the items in our catalogue."
]
},
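{
"cell_type": "markdown",
"id": "cold-start-sketch-md",
"metadata": {},
"source": [
"A minimal NumPy sketch of that fallback, assuming the trained `car_id` embedding table is available as an array of shape `(num_items, dim)`. The table below is random and `embedding_for` is a hypothetical helper; with a real model you would read the weights from the trained embedding layer. Row 0 is left out of the mean because index 0 is the one used for unknown categories."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cold-start-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"num_items, dim = 5, 16  # 4 cars plus index 0, used for unknown categories\n",
"item_embeddings = rng.normal(size=(num_items, dim))  # stand-in for trained weights\n",
"\n",
"# Mean over the known items only (rows 1..num_items-1)\n",
"mean_item_embedding = item_embeddings[1:].mean(axis=0)\n",
"\n",
"\n",
"def embedding_for(item_id, known_ids):\n",
"    # Use the learned embedding when we have one, else fall back to the catalogue mean\n",
"    if item_id in known_ids:\n",
"        return item_embeddings[item_id]\n",
"    return mean_item_embedding\n",
"\n",
"\n",
"known_ids = {1, 2, 3, 4}\n",
"cold = embedding_for(7, known_ids)  # hypothetical new car with no interactions\n",
"print(cold.shape)  # (16,)"
]
},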
{
"cell_type": "markdown",
"id": "f59966aa",
"metadata": {},
"source": [
"# Process data"
]
},
{
"cell_type": "markdown",
"id": "78a50712",
"metadata": {},
"source": [
"We perform some minimal preprocessing.\n",
"\n",
"The most interesting finding: if we do not normalize `user_age`, the model trains much worse and does not reach its expected predictive power.\n",
"\n",
"It would be worth investigating why this is the case. It is likely that the model trains poorly because neural networks struggle with raw, unscaled numerical inputs like `user_age`, which here ranges from 30 to 89.\n",
"\n",
"This is a good reminder of the need to normalize data (I thought it was less important nowadays, with all the improvements built into modern optimizers, BatchNorm, etc.)."
]
},
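{
"cell_type": "markdown",
"id": "normalize-sketch-md",
"metadata": {},
"source": [
"To make the preprocessing concrete, here is a NumPy sketch of the standardization that `nvt.ops.Normalize` applies to a continuous column: subtract the mean and divide by the standard deviation, both computed when the workflow is fitted. The ages are made-up sample values, and NVTabular's choice of standard-deviation estimator may differ slightly from `np.std`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "normalize-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Made-up raw ages in the same 30-89 range as the generated data\n",
"ages = np.array([32, 40, 44, 65, 72, 83, 86, 89], dtype=np.float64)\n",
"\n",
"# Statistics are computed once, at fit time...\n",
"mean, std = ages.mean(), ages.std()\n",
"\n",
"# ...and reused to transform every value, at training and at inference\n",
"normalized = (ages - mean) / std\n",
"\n",
"print(np.round(normalized, 3))\n",
"print(f'mean={normalized.mean():.3f}, std={normalized.std():.3f}')"
]
},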
{
"cell_type": "code",
"execution_count": 5,
"id": "33b9b8ae",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].\n",
" warnings.warn(\n",
"/usr/local/lib/python3.8/dist-packages/cudf/core/frame.py:384: UserWarning: The deep parameter is ignored and is only included for pandas compatibility.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user_type</th>\n",
" <th>user_age</th>\n",
" <th>car_make</th>\n",
" <th>car_color</th>\n",
" <th>car_id</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>8579</th>\n",
" <td>1</td>\n",
" <td>-1.633240</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11604</th>\n",
" <td>1</td>\n",
" <td>-0.026806</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12228</th>\n",
" <td>1</td>\n",
" <td>-0.084178</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>997</th>\n",
" <td>2</td>\n",
" <td>-1.174259</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13726</th>\n",
" <td>1</td>\n",
" <td>-1.633240</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" user_type user_age car_make car_color car_id\n",
"8579 1 -1.633240 2 1 1\n",
"11604 1 -0.026806 2 1 1\n",
"12228 1 -0.084178 2 1 1\n",
"997 2 -1.174259 1 1 3\n",
"13726 1 -1.633240 2 1 1"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from merlin.schema.tags import Tags\n",
"from merlin.models.utils import schema_utils\n",
"\n",
"# User tower features: encoded user type and standardized (zero mean, unit variance) age\n",
"user_type = ['user_type'] >> nvt.ops.Categorify()\n",
"user_age = ['user_age'] >> nvt.ops.Normalize()\n",
"\n",
"user_features = user_type + user_age >> nvt.ops.TagAsUserFeatures()\n",
"\n",
"# Item tower features: encoded car color, car make and car id\n",
"car_color = ['car_color'] >> nvt.ops.Categorify()\n",
"car_make = ['car_make'] >> nvt.ops.Categorify()\n",
"\n",
"car_id = ['car_id'] >> nvt.ops.Categorify() >> nvt.ops.TagAsItemID()\n",
"car_features = car_make + car_color + car_id >> nvt.ops.TagAsItemFeatures()\n",
"\n",
"ds = nvt.Dataset(df)\n",
"wf = nvt.Workflow(user_features + car_features)\n",
"train_dataset = wf.fit_transform(ds)\n",
"\n",
"train_dataset.compute().sample(n=5)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fabd277a",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>name</th>\n",
" <th>tags</th>\n",
" <th>dtype</th>\n",
" <th>is_list</th>\n",
" <th>is_ragged</th>\n",
" <th>properties.num_buckets</th>\n",
" <th>properties.freq_threshold</th>\n",
" <th>properties.max_size</th>\n",
" <th>properties.start_index</th>\n",
" <th>properties.cat_path</th>\n",
" <th>properties.domain.min</th>\n",
" <th>properties.domain.max</th>\n",
" <th>properties.domain.name</th>\n",
" <th>properties.embedding_sizes.cardinality</th>\n",
" <th>properties.embedding_sizes.dimension</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>user_type</td>\n",
" <td>(Tags.CATEGORICAL, Tags.USER)</td>\n",
" <td>int64</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>NaN</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>.//categories/unique.user_type.parquet</td>\n",
" <td>0.0</td>\n",
" <td>2.0</td>\n",
" <td>user_type</td>\n",
" <td>3.0</td>\n",
" <td>16.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>user_age</td>\n",
" <td>(Tags.CONTINUOUS, Tags.USER)</td>\n",
" <td>float64</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>car_make</td>\n",
" <td>(Tags.ITEM, Tags.CATEGORICAL)</td>\n",
" <td>int64</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>NaN</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>.//categories/unique.car_make.parquet</td>\n",
" <td>0.0</td>\n",
" <td>3.0</td>\n",
" <td>car_make</td>\n",
" <td>4.0</td>\n",
" <td>16.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>car_color</td>\n",
" <td>(Tags.ITEM, Tags.CATEGORICAL)</td>\n",
" <td>int64</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>NaN</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>.//categories/unique.car_color.parquet</td>\n",
" <td>0.0</td>\n",
" <td>2.0</td>\n",
" <td>car_color</td>\n",
" <td>3.0</td>\n",
" <td>16.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>car_id</td>\n",
" <td>(Tags.ITEM, Tags.ITEM_ID, Tags.CATEGORICAL, Ta...</td>\n",
" <td>int64</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>NaN</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>.//categories/unique.car_id.parquet</td>\n",
" <td>0.0</td>\n",
" <td>4.0</td>\n",
" <td>car_id</td>\n",
" <td>5.0</td>\n",
" <td>16.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"[{'name': 'user_type', 'tags': {<Tags.CATEGORICAL: 'categorical'>, <Tags.USER: 'user'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.user_type.parquet', 'domain': {'min': 0, 'max': 2, 'name': 'user_type'}, 'embedding_sizes': {'cardinality': 3, 'dimension': 16}}, 'dtype': dtype('int64'), 'is_list': False, 'is_ragged': False}, {'name': 'user_age', 'tags': {<Tags.CONTINUOUS: 'continuous'>, <Tags.USER: 'user'>}, 'properties': {}, 'dtype': dtype('float64'), 'is_list': False, 'is_ragged': False}, {'name': 'car_make', 'tags': {<Tags.ITEM: 'item'>, <Tags.CATEGORICAL: 'categorical'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.car_make.parquet', 'domain': {'min': 0, 'max': 3, 'name': 'car_make'}, 'embedding_sizes': {'cardinality': 4, 'dimension': 16}}, 'dtype': dtype('int64'), 'is_list': False, 'is_ragged': False}, {'name': 'car_color', 'tags': {<Tags.ITEM: 'item'>, <Tags.CATEGORICAL: 'categorical'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.car_color.parquet', 'domain': {'min': 0, 'max': 2, 'name': 'car_color'}, 'embedding_sizes': {'cardinality': 3, 'dimension': 16}}, 'dtype': dtype('int64'), 'is_list': False, 'is_ragged': False}, {'name': 'car_id', 'tags': {<Tags.ITEM: 'item'>, <Tags.ITEM_ID: 'item_id'>, <Tags.CATEGORICAL: 'categorical'>, <Tags.ID: 'id'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.car_id.parquet', 'domain': {'min': 0, 'max': 4, 'name': 'car_id'}, 'embedding_sizes': {'cardinality': 5, 'dimension': 16}}, 'dtype': dtype('int64'), 'is_list': False, 'is_ragged': False}]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_dataset.schema"
]
},
{
"cell_type": "markdown",
"id": "d2a8afa9",
"metadata": {},
"source": [
"# Train the model"
]
},
{
"cell_type": "markdown",
"id": "a533e2e6",
"metadata": {},
"source": [
"Time to train the model.\n",
"\n",
"We will train a model with all the goodies (`PopularityLogitsCorrection`, `logits_temperature`, etc.).\n",
"\n",
"These values are good defaults, but you might want to tweak them when training your own models."
]
},
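{
"cell_type": "markdown",
"id": "logits-correction-sketch-md",
"metadata": {},
"source": [
"Roughly speaking, both settings reshape the logits before the softmax. The NumPy sketch below assumes a common formulation: divide the logits by the temperature, then subtract the log of each candidate item's sampling probability (estimated from item frequencies) to correct for in-batch negatives being drawn in proportion to their popularity, the so-called logQ correction. It is an illustration only, not Merlin's implementation; the order of the steps and the exact scaling used by `PopularityLogitsCorrection` and `logits_temperature` may differ. `corrected_logits` is a hypothetical helper, and the frequencies are the `car_id` counts computed a few cells below."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "logits-correction-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def corrected_logits(logits, item_freq, temperature=1.8, reg_factor=1.0):\n",
"    # Hypothetical helper: temperature scaling followed by a logQ correction.\n",
"    # logits: (batch, num_candidates) query-item dot products\n",
"    # item_freq: (num_candidates,) interaction counts per candidate item\n",
"    probs = item_freq / item_freq.sum()\n",
"    log_q = np.log(np.clip(probs, 1e-12, None))  # clip avoids log(0) for unseen items\n",
"    return logits / temperature - reg_factor * log_q\n",
"\n",
"\n",
"def softmax(x):\n",
"    z = x - x.max(axis=-1, keepdims=True)  # subtract the max for numerical stability\n",
"    e = np.exp(z)\n",
"    return e / e.sum(axis=-1, keepdims=True)\n",
"\n",
"\n",
"logits = np.array([[2.0, 1.0, 0.5, 0.1]])\n",
"item_freq = np.array([5620.0, 5353.0, 2792.0, 2647.0])  # car_id counts from this notebook\n",
"print(softmax(logits).round(3))\n",
"print(softmax(corrected_logits(logits, item_freq)).round(3))"
]
},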
{
"cell_type": "code",
"execution_count": 7,
"id": "65ef5b3e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-09-06 03:55:54.309967: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-09-06 03:55:54.310411: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-09-06 03:55:54.310548: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-09-06 03:55:54.321597: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2022-09-06 03:55:54.322537: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-09-06 03:55:54.322704: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-09-06 03:55:54.322841: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-09-06 03:55:54.323136: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-09-06 03:55:54.323280: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-09-06 03:55:54.323417: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:991] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
"2022-09-06 03:55:54.323536: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 24576 MB memory: -> device: 0, name: Quadro RTX 8000, pci bus id: 0000:08:00.0, compute capability: 7.5\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"import merlin.models.tf as mm\n",
"from merlin.models.tf.dataset import BatchedDataset\n",
"from merlin.models.utils import schema_utils\n",
"from merlin.models.tf.core.transformations import PopularityLogitsCorrection\n",
"from tensorflow.keras import regularizers"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "d3dbde82",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>freq</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5620</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>5353</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2792</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2647</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" freq\n",
"1 5620\n",
"2 5353\n",
"3 2792\n",
"4 2647"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"item_frequencies_df = train_dataset.compute()['car_id'].value_counts().rename('freq').to_frame()\n",
"item_frequencies_df"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2fe1e400",
"metadata": {},
"outputs": [],
"source": [
"item_frequencies_df.loc[0] = 0 # add an entry for index 0, the index used for unknown categories\n",
"item_frequencies_df.sort_index(inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5d093c17",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(5,), dtype=int64, numpy=array([ 0, 5620, 5353, 2792, 2647])>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"item_frequencies = tf.convert_to_tensor(\n",
" item_frequencies_df['freq'].values.get()\n",
")\n",
"item_frequencies"
]
},
{
"cell_type": "markdown",
"id": "c48da98f",
"metadata": {},
"source": [
"We just pass in the schema and the correct architecture springs to life 💫"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "bd647f0f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].\n",
" warnings.warn(\n"
]
}
],
"source": [
"post_logits = PopularityLogitsCorrection(item_frequencies, reg_factor=1, schema=train_dataset.schema)\n",
"item_retrieval_task = mm.ItemRetrievalTask(train_dataset.schema, logits_temperature=1.8, post_logits=post_logits)\n",
"\n",
"model = mm.TwoTowerModel(\n",
" train_dataset.schema,\n",
" query_tower=mm.MLPBlock(\n",
" [64],\n",
" activation='relu',\n",
" no_activation_last_layer=True, \n",
" dropout=0.3, \n",
" kernel_regularizer=regularizers.l2(5e-5),\n",
" bias_regularizer=regularizers.l2(5e-5),\n",
" ),\n",
" embedding_options=mm.EmbeddingOptions(\n",
" infer_embedding_sizes=True,\n",
" infer_embedding_sizes_multiplier=2,\n",
" embeddings_l2_reg=1e-5,\n",
" ),\n",
" prediction_tasks=item_retrieval_task\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "5cf94697",
"metadata": {},
"outputs": [],
"source": [
"metrics = [mm.TopKMetricsAggregator(mm.RecallAt(10), mm.MRRAt(10))]\n",
"\n",
"learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(\n",
" 1e-2,\n",
" decay_steps=10,\n",
" decay_rate=0.96,\n",
" staircase=True,\n",
")\n",
"optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n",
"loss = tf.keras.losses.CategoricalCrossentropy(\n",
" from_logits=True, label_smoothing=0,\n",
")\n",
"\n",
"model.compile(optimizer, loss=loss, metrics=metrics, run_eagerly=False)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "352659fe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/6\n",
"WARNING:tensorflow:AutoGraph could not transform <bound method Socket.send of <zmq.Socket(zmq.PUSH) at 0x7f834afde280>> and will run it as-is.\n",
"Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
"Cause: module, class, method, function, traceback, frame, or code object was expected, got cython_function_or_method\n",
"To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
"WARNING: AutoGraph could not transform <bound method Socket.send of <zmq.Socket(zmq.PUSH) at 0x7f834afde280>> and will run it as-is.\n",
"Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
"Cause: module, class, method, function, traceback, frame, or code object was expected, got cython_function_or_method\n",
"To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"The sampler InBatchSampler returned no samples for this batch.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"16/16 [==============================] - 5s 5ms/step - loss: 6.0230 - recall_at_10: 0.3760 - mrr_at_10: 0.3760 - regularization_loss: 0.0030\n",
"Epoch 2/6\n",
"16/16 [==============================] - 0s 4ms/step - loss: 3.0718 - recall_at_10: 0.7227 - mrr_at_10: 0.7227 - regularization_loss: 0.0099\n",
"Epoch 3/6\n",
"16/16 [==============================] - 0s 5ms/step - loss: 0.6785 - recall_at_10: 0.9189 - mrr_at_10: 0.9189 - regularization_loss: 0.0178\n",
"Epoch 4/6\n",
"16/16 [==============================] - 0s 5ms/step - loss: 0.2831 - recall_at_10: 0.9824 - mrr_at_10: 0.9824 - regularization_loss: 0.0233\n",
"Epoch 5/6\n",
"16/16 [==============================] - 0s 6ms/step - loss: 0.1901 - recall_at_10: 1.0000 - mrr_at_10: 1.0000 - regularization_loss: 0.0264\n",
"Epoch 6/6\n",
"16/16 [==============================] - 0s 5ms/step - loss: 0.1457 - recall_at_10: 1.0000 - mrr_at_10: 1.0000 - regularization_loss: 0.0285\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f82b81d4190>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(\n",
" train_dataset,\n",
" epochs=6,\n",
" batch_size=1024,\n",
" shuffle=True,\n",
" drop_last=True,\n",
" train_metrics_steps=20,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "09ca1cfc",
"metadata": {},
"source": [
"The model reaches perfect performance on the training data (`recall_at_10` and `mrr_at_10` both reach 1.0).\n",
"\n",
"That is what we expected, given the simple rules embodied in our data.\n",
"\n",
"But that was the whole point: this is a toy model to help us familiarize ourselves with training a Two Tower model, and it is much easier to wrap our heads around it with data that we understand inside out."
]
},
{
"cell_type": "markdown",
"id": "6d05dc26",
"metadata": {},
"source": [
"## Perform offline batch prediction"
]
},
{
"cell_type": "markdown",
"id": "53b8441c",
"metadata": {},
"source": [
"Let us now perform offline prediction.\n",
"\n",
"We begin by creating a dataset of items. When presented with a user, our model will predict which item the customer is most likely to buy."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "d2284c09",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>car_make</th>\n",
" <th>car_color</th>\n",
" <th>car_id</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>13764</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8144</th>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2791</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16411</th>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" car_make car_color car_id\n",
"13764 2 1 1\n",
"8144 1 2 2\n",
"2791 1 1 3\n",
"16411 3 2 4"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"item_features = train_dataset.schema.select_by_tag(Tags.ITEM).column_names\n",
"item_dataset = train_dataset.to_ddf()[item_features].drop_duplicates(subset=['car_make', 'car_color', 'car_id'], keep='last').sort_values('car_id').compute()\n",
"item_dataset"
]
},
{
"cell_type": "markdown",
"id": "f0ffba78",
"metadata": {},
"source": [
"We transform our model into a top-k recommender.\n",
"\n",
"In training, it learned to score how likely a customer was to buy a given car (when presented with the `car_id` and car features), ranking the car each user actually bought above the other cars in the same batch.\n",
"\n",
"We now want to ask a different question:\n",
"\n",
"> For this particular user, with this set of characteristics, which car, out of the available cars, are they most likely to buy?"
]
},
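{
"cell_type": "markdown",
"id": "top-k-sketch-md",
"metadata": {},
"source": [
"Conceptually, a top-k recommender built from a Two Tower model embeds the user with the query tower, scores that embedding against every candidate item embedding (typically with a dot product) and keeps the k highest-scoring items. Below is a brute-force NumPy sketch of that idea with random stand-in embeddings; `top_k_items` is a hypothetical helper, and `to_top_k_recommender` may implement the search differently."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "top-k-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def top_k_items(query_emb, item_emb, item_ids, k):\n",
"    # query_emb: (num_users, dim), item_emb: (num_items, dim)\n",
"    scores = query_emb @ item_emb.T  # (num_users, num_items) dot-product scores\n",
"    order = np.argsort(-scores, axis=1)[:, :k]  # best-scoring items first\n",
"    top_scores = np.take_along_axis(scores, order, axis=1)\n",
"    return top_scores, item_ids[order]\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"user_emb = rng.normal(size=(3, 8))  # stand-in query tower outputs\n",
"car_emb = rng.normal(size=(4, 8))  # stand-in item tower outputs\n",
"car_ids = np.array([1, 2, 3, 4])  # same ids as in item_dataset\n",
"\n",
"scores, ids = top_k_items(user_emb, car_emb, car_ids, k=4)\n",
"print(ids[:, 0])  # top-1 car_id per user, analogous to cls_idxs[:, 0] further down"
]
},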
{
"cell_type": "code",
"execution_count": 15,
"id": "6d7d3a15",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:Found untraced functions such as model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, parallel_block_layer_call_fn, parallel_block_layer_call_and_return_conditional_losses, sequential_block_6_layer_call_fn while saving (showing 5 of 26). These functions will not be directly callable after loading.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: /tmp/tmplf3rfudf/assets\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: /tmp/tmplf3rfudf/assets\n",
"/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
]
}
],
"source": [
"item_dataset = nvt.Dataset(item_dataset)\n",
"top_k_rec = model.to_top_k_recommender(item_dataset, 4)"
]
},
{
"cell_type": "markdown",
"id": "348db7d4",
"metadata": {},
"source": [
"Here are all the unique users in our dataset (unique combinations of `user_type` and `user_age`; only the first five are shown).\n",
"\n",
"Of course, I am using the training data for prediction here, which is not what you would normally do.\n",
"\n",
"But that is only to demonstrate the functionality in the most straightforward way possible."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "2d45635b",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user_type</th>\n",
" <th>user_age</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>13656</th>\n",
" <td>1</td>\n",
" <td>-1.690613</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13757</th>\n",
" <td>1</td>\n",
" <td>-1.633240</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13719</th>\n",
" <td>1</td>\n",
" <td>-1.575868</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13753</th>\n",
" <td>1</td>\n",
" <td>-1.518495</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13722</th>\n",
" <td>1</td>\n",
" <td>-1.461122</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" user_type user_age\n",
"13656 1 -1.690613\n",
"13757 1 -1.633240\n",
"13719 1 -1.575868\n",
"13753 1 -1.518495\n",
"13722 1 -1.461122"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"users_schema = train_dataset.schema.select_by_tag(Tags.USER)\n",
"user_features = users_schema.column_names\n",
"\n",
"unique_users = train_dataset.to_ddf()[user_features].drop_duplicates(subset=['user_type', 'user_age'], keep='last').compute()\n",
"unique_users.head()"
]
},
{
"cell_type": "markdown",
"id": "5edd1356",
"metadata": {},
"source": [
"We now take our users and output the predictions.\n",
"\n",
"But before we do so, let us see what the ground truth looks like."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "0582a7a0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 4, 4, 4,\n",
" 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3,\n",
" 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
" 2, 2, 2, 2, 2, 2, 2, 2])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_dataset.compute().loc[unique_users.index]['car_id'].values"
]
},
{
"cell_type": "markdown",
"id": "aac5daa6",
"metadata": {},
"source": [
"And here are the predictions from our model."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "eed40df5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2/2 [==============================] - 0s 2ms/step\n"
]
}
],
"source": [
"users_dataset = nvt.Dataset(unique_users, schema=users_schema)\n",
"probs, cls_idxs = top_k_rec.predict(BatchedDataset(users_dataset, 100, shuffle=False, schema=users_schema))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "6ffb3f39",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 4, 4, 4,\n",
" 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3,\n",
" 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
" 2, 2, 2, 2, 2, 2, 2, 2])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cls_idxs[:, 0]"
]
},
{
"cell_type": "markdown",
"id": "ed03262c",
"metadata": {},
"source": [
"Not surprisingly, the ground truth and the predictions match!\n",
"\n",
"We have demonstrated that the model functions as expected on our toy data."
]
},
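{
"cell_type": "markdown",
"id": "hit-rate-sketch-md",
"metadata": {},
"source": [
"Rather than comparing the two arrays by eye, the agreement can be summarized as a hit rate at k (with k=1 this is top-1 accuracy). A minimal NumPy sketch; `hit_rate_at_k` is a hypothetical helper and the arrays are toy values. To apply it here, pass the ground-truth `car_id` values and `cls_idxs`; since `.values` on a cuDF Series returns a CuPy array, convert the ground truth with `.to_numpy()` first."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "hit-rate-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def hit_rate_at_k(y_true, top_k_ids, k):\n",
"    # Fraction of users whose true item is among their first k recommendations\n",
"    hits = (top_k_ids[:, :k] == y_true[:, None]).any(axis=1)\n",
"    return hits.mean()\n",
"\n",
"\n",
"# Toy example: 4 users, 3 recommendations each\n",
"y_true = np.array([1, 4, 3, 2])\n",
"top_k_ids = np.array([\n",
"    [1, 2, 3],\n",
"    [3, 4, 1],\n",
"    [3, 1, 2],\n",
"    [1, 3, 4],\n",
"])\n",
"print(hit_rate_at_k(y_true, top_k_ids, k=1))  # 0.5\n",
"print(hit_rate_at_k(y_true, top_k_ids, k=3))  # 0.75"
]
},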
{
"cell_type": "markdown",
"id": "8b0cd9c7",
"metadata": {},
"source": [
"## Summary"
]
},
{
"cell_type": "markdown",
"id": "4f5f01b8",
"metadata": {},
"source": [
"The Two Tower model has a lot of interesting characteristics (it does not combine user and item data until the very end, it can make use of both continuous and categorical features, etc.). It is an important architecture and one that can be used across a wide range of scenarios.\n",
"\n",
"If you would like to learn more about Two Tower and how it can be used, consider reading the following blog post: [Scale faster with less code using Two Tower with Merlin](https://medium.com/nvidia-merlin/scale-faster-with-less-code-using-two-tower-with-merlin-c16f32aafa9f).\n",
"\n",
"If you would rather jump straight into the code, there are two tutorials I would recommend:\n",
"\n",
"* [Two-Stage Recommender Systems](https://github.com/NVIDIA-Merlin/models/blob/main/examples/05-Retrieval-Model.ipynb)\n",
" * learn how to train various retrieval models; Two Tower is one of the architectures discussed\n",
"* [Deploying a Multi-Stage Recommender System](https://github.com/NVIDIA-Merlin/Merlin/tree/main/examples/Building-and-deploying-multi-stage-RecSys)\n",
" * learn how to deploy an ensemble of models in production, with Two Tower being one of them (note: you could also deploy only the Two Tower model, without the additional functionality, if that is what you want)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}