@aliciawilliams
Last active October 6, 2023 18:22
Customer Segmentation Demo Notebook
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"source": [
"# Copyright 2023 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
],
"metadata": {
"id": "rlporafx8Odu"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Using the Vertex AI PaLM API to explain BQML Clustering"
],
"metadata": {
"id": "Sxq-V_jV4gEk"
}
},
{
"cell_type": "markdown",
"source": [
"This sample demonstrates how to use the Vertex AI PaLM API to explain BQML clustering.\n",
"\n",
"To run the notebook in your own project, first create a BigQuery dataset named `cluster_demo`.\n",
"\n",
"You can upload this notebook into BigQuery Studio following [these instructions](https://cloud.google.com/bigquery/docs/create-notebooks#upload_notebooks)."
],
"metadata": {
"id": "sHoUOuxf5GOR"
}
},
{
"cell_type": "markdown",
"source": [
"## Create a K-means model to cluster ecommerce data\n",
"\n",
"First, let's take a quick look at the data before we create the model."
],
"metadata": {
"id": "WnXQQOAc4t22"
}
},
{
"cell_type": "code",
"source": [
"%%bigquery\n",
"SELECT\n",
" user_id,\n",
" order_id,\n",
" sale_price,\n",
" created_at as order_created_date\n",
"FROM `bigquery-public-data.thelook_ecommerce.order_items`\n",
"WHERE created_at BETWEEN CAST('2020-01-01 00:00:00' AS TIMESTAMP)\n",
"AND CAST('2023-01-01 00:00:00' AS TIMESTAMP)"
],
"metadata": {
"id": "V6-JiJZBGmO7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"A common approach to customer segmentation is the Recency, Frequency, Monetary value (RFM) model. Let's use SQL to calculate these metrics for each customer (`user_id`) and save them to a new table called `customer_stats`."
],
"metadata": {
"id": "ccCMFyWkJ4aP"
}
},
{
"cell_type": "code",
"source": [
"%%bigquery\n",
"CREATE OR REPLACE TABLE cluster_demo.customer_stats AS\n",
"SELECT\n",
" user_id,\n",
" DATE_DIFF(CURRENT_DATE(), CAST(MAX(order_created_date) AS DATE), day) AS days_since_last_order, ---RECENCY\n",
" COUNT(order_id) AS count_orders, --FREQUENCY\n",
" AVG(sale_price) AS average_spend --MONETARY\n",
" FROM (\n",
" SELECT\n",
" user_id,\n",
" order_id,\n",
" sale_price,\n",
" created_at as order_created_date\n",
"\n",
" FROM `bigquery-public-data.thelook_ecommerce.order_items`\n",
" WHERE\n",
" created_at\n",
" BETWEEN CAST('2020-01-01 00:00:00' AS TIMESTAMP)\n",
" AND CAST('2023-01-01 00:00:00' AS TIMESTAMP)\n",
" )\n",
"GROUP BY user_id;"
],
"metadata": {
"id": "mvyauyewOkxw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Now, let's take a look at the transformed data."
],
"metadata": {
"id": "GRLAsq78KJoU"
}
},
{
"cell_type": "code",
"source": [
"%%bigquery\n",
"SELECT * FROM cluster_demo.customer_stats"
],
"metadata": {
"id": "JSWi6zsCO0lp"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"We can now create our customer segmentation model with the BQML `CREATE MODEL` statement. In the `OPTIONS` clause, we set `MODEL_TYPE` to `KMEANS`.\n",
"\n",
"In the query below, we deliberately ask the model to create 5 clusters. The right number of clusters depends on your data and on how closely you want the model to fit it. We won't cover hyperparameter tuning in this notebook, but you can read more about choosing the number of clusters in the [BigQuery ML hyperparameter tuning documentation](https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-hyperparameter-tuning).\n"
],
"metadata": {
"id": "LoYUdaDdKa-y"
}
},
{
"cell_type": "code",
"source": [
"%%bigquery\n",
"CREATE OR REPLACE MODEL `cluster_demo.ecommerce_customer_segment_cluster5`\n",
" OPTIONS (\n",
" MODEL_TYPE = \"KMEANS\",\n",
" NUM_CLUSTERS = 5,\n",
" KMEANS_INIT_METHOD = \"KMEANS++\",\n",
" STANDARDIZE_FEATURES = TRUE )\n",
"AS (\n",
"SELECT * EXCEPT (user_id) FROM cluster_demo.customer_stats\n",
" )"
],
"metadata": {
"id": "n2APzYZjPO0R"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Once the model finishes training, we can use `ML.EVALUATE` to check its clustering performance with two metrics: the [Davies-Bouldin index](https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index) and the mean squared distance. Lower values are better for both, although mean squared distance naturally shrinks as you add clusters."
],
"metadata": {
"id": "mYN7F_f3PeE5"
}
},
{
"cell_type": "code",
"source": [
"%%bigquery\n",
"SELECT *\n",
"FROM ML.EVALUATE(MODEL cluster_demo.ecommerce_customer_segment_cluster5)\n"
],
"metadata": {
"id": "wDy9-26SFnuS"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Now let's run our transformed dataset through the new model with the `ML.PREDICT` function to determine each user's cluster, or segment."
],
"metadata": {
"id": "vDDtQvi9LUWc"
}
},
{
"cell_type": "code",
"source": [
"%%bigquery\n",
"CREATE OR REPLACE TABLE cluster_demo.segment_results AS\n",
"SELECT\n",
" * EXCEPT(nearest_centroids_distance)\n",
"FROM\n",
" ML.PREDICT( MODEL cluster_demo.ecommerce_customer_segment_cluster5,\n",
" (\n",
" SELECT\n",
" *\n",
" FROM\n",
" cluster_demo.customer_stats\n",
"))"
],
"metadata": {
"id": "3cuLNDcMBfTC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The previous SQL query stored the results in a new table; let's take a look at them. `CENTROID_ID` identifies which of the 5 segments each user falls closest to."
],
"metadata": {
"id": "loYpisZgLxBY"
}
},
{
"cell_type": "code",
"source": [
"%%bigquery\n",
"SELECT * FROM cluster_demo.segment_results"
],
"metadata": {
"id": "7-qSG3tyD9gr"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Let's plot each of the customers and their predicted segments (by color) against their `average_spend` and `days_since_last_order` attributes.\n",
"\n",
"While this is only a 2-dimensional representation that leaves out the third attribute, `count_orders`, it still gives a useful partial view of the segments."
],
"metadata": {
"id": "a9r_lDCwMHjh"
}
},
{
"cell_type": "code",
"source": [
"import bigframes.pandas as bpd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"# Load the segment assignments with BigQuery DataFrames, then bring a 20% sample into pandas for plotting.\n",
"segments_df = bpd.read_gbq(\"cluster_demo.segment_results\")\n",
"segments_df = segments_df.rename(columns={'CENTROID_ID': 'segment_id'})\n",
"segments_pd = segments_df.sample(frac=0.2, random_state=1).to_pandas()\n",
"\n",
"# Generate a scatterplot, colored by segment, to display the clusters.\n",
"g = sns.lmplot(x='days_since_last_order', y='average_spend', data=segments_pd, fit_reg=False, hue='segment_id', palette='Set2', height=7, aspect=2)\n",
"# The axis limits are fixed; widen them if points fall outside this range\n",
"# (days_since_last_order grows over time because it is measured from CURRENT_DATE()).\n",
"g = (g.set_axis_labels('Days Since Last Order','Average Spend ($)', fontsize=15).set(xlim=(200,1000),ylim=(0,400)))\n",
"plt.title('Attribute Grouped by K-means Cluster', pad=20, fontsize=20)"
],
"metadata": {
"id": "bbIEZzxvFNjH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Now let's get the cluster (centroid) information from `ML.CENTROIDS` and pivot it so that each row describes one cluster."
],
"metadata": {
"id": "t1RNh1NqPpJ2"
}
},
{
"cell_type": "code",
"source": [
"%%bigquery centroid_df --project YOUR-PROJECT-ID-HERE\n",
"-- TODO: Replace YOUR-PROJECT-ID-HERE in the line above with your own project ID.\n",
"SELECT\n",
" CONCAT('cluster ', CAST(centroid_id as STRING)) as centroid,\n",
" average_spend,\n",
" count_orders as count_of_orders,\n",
" days_since_last_order\n",
"FROM (\n",
" SELECT centroid_id, feature, ROUND(numerical_value, 2) as value\n",
" FROM ML.CENTROIDS(MODEL cluster_demo.ecommerce_customer_segment_cluster5)\n",
")\n",
"PIVOT (\n",
" SUM(value)\n",
" FOR feature IN ('average_spend', 'count_orders', 'days_since_last_order')\n",
")\n",
"ORDER BY centroid_id"
],
"metadata": {
"id": "H2czyn5zPvhJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"centroid_df.head()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"id": "ajnK6MvdwfuL",
"executionInfo": {
"status": "ok",
"timestamp": 1696382011272,
"user_tz": 420,
"elapsed": 217,
"user": {
"displayName": "",
"userId": ""
}
},
"outputId": "f83fd94c-de9e-488c-82e0-e880a5e14454"
},
"execution_count": 23,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" centroid average_spend count_of_orders days_since_last_order\n",
"0 cluster 1 43.74 1.34 475.69\n",
"1 cluster 2 503.62 1.34 601.16\n",
"2 cluster 3 169.92 1.41 589.70\n",
"3 cluster 4 47.75 1.52 1013.93\n",
"4 cluster 5 56.60 4.06 493.77"
]
},
"metadata": {},
"execution_count": 23
}
]
},
{
"cell_type": "markdown",
"source": [
"Whew! That's a lot of metrics and cluster info. How about we explain it to our colleagues using the magic of LLMs?"
],
"metadata": {
"id": "Xhha_o_qPwJY"
}
},
{
"cell_type": "code",
"source": [
"# Build one human-readable line per cluster to use in the LLM prompt.\n",
"cluster_info = []\n",
"for _, row in centroid_df.iterrows():\n",
"    cluster_info.append(\n",
"        f\"{row['centroid']}, average spend ${row['average_spend']}, \"\n",
"        f\"count of orders per person {row['count_of_orders']}, \"\n",
"        f\"days since last order {row['days_since_last_order']}\"\n",
"    )\n",
"\n",
"print(\"\\n\".join(cluster_info))"
],
"metadata": {
"id": "AWzyZUFdRQll"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Explain with Vertex AI PaLM API"
],
"metadata": {
"id": "7Na0_rVVLhOk"
}
},
{
"cell_type": "markdown",
"source": [
"PaLM 2 is a large language model that can help us explain our data to different audiences.\n",
"\n",
"In this case, we’re going to use the PaLM API to describe each customer segment and suggest a good next marketing action for it."
],
"metadata": {
"id": "tdoRB7RjLso6"
}
},
{
"cell_type": "code",
"source": [
"from vertexai.language_models import TextGenerationModel\n",
"\n",
"model = TextGenerationModel.from_pretrained(\"text-bison@001\")\n",
"\n",
"# Combine the per-cluster descriptions into a single block of text for the prompt.\n",
"clusters = \"\\n\".join(cluster_info)\n",
"\n",
"prompt = f\"\"\"\n",
"You're a creative brand strategist. Given the following clusters, come up with a creative brand persona, a catchy title, and a next marketing action for each one, explained step by step.\n",
"\n",
"Clusters:\n",
"{clusters}\n",
"\n",
"For each Cluster:\n",
"* Title:\n",
"* Persona:\n",
"* Next Marketing Step:\n",
"\"\"\"\n",
"\n",
"response = model.predict(\n",
" prompt,\n",
" max_output_tokens=1024,\n",
" temperature=0.55,\n",
" top_p=0.8,\n",
" top_k=40,\n",
")\n",
"print(response.text)"
],
"metadata": {
"id": "MYwQNH0ELu9y"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Voilà! We've now used k-means clustering to group customers into segments and an LLM to explain their profiles."
],
"metadata": {
"id": "jqC5uUguaunY"
}
},
{
"cell_type": "markdown",
"source": [
"Sometimes, though, you want a little bit [extra](https://cloud.google.com/blog/transform/prompt-debunking-five-generative-ai-misconceptions)."
],
"metadata": {
"id": "5HEVHGmmRyGW"
}
},
{
"cell_type": "code",
"source": [
"from vertexai.language_models import TextGenerationModel\n",
"\n",
"model = TextGenerationModel.from_pretrained(\"text-bison@001\")\n",
"\n",
"# Join into a new variable instead of overwriting the cluster_info list, so the cells can be re-run safely.\n",
"clusters = \"\\n\".join(cluster_info)\n",
"\n",
"prompt = f\"\"\"\n",
"Pretend you're a creative strategist. Analyze the following clusters and come up with a \\\n",
"creative brand persona for each that includes which Taylor Swift song is \\\n",
"likely to be their favorite, a summary of how this relates to their purchasing behavior, \\\n",
"and a witty email headline for a marketing campaign targeted to their group.\n",
"\n",
"Clusters:\n",
"{clusters}\n",
"\"\"\"\n",
"\n",
"response = model.predict(\n",
" prompt,\n",
" max_output_tokens=1024,\n",
" temperature=0.45,\n",
" top_p=0.8,\n",
" top_k=40,\n",
")\n",
"print(response.text)"
],
"metadata": {
"id": "3HgwtN1tR3xQ"
},
"execution_count": null,
"outputs": []
}
]
}
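To check the RFM logic outside BigQuery, here is a minimal pandas sketch of what the `cluster_demo.customer_stats` query computes. It assumes a DataFrame with the same four columns as the inner `SELECT` (`user_id`, `order_id`, `sale_price`, `order_created_date`, as tz-naive timestamps) and a fixed `as_of` date standing in for `CURRENT_DATE()`; the function name `rfm_features` is hypothetical. Because `order_items` has one row per line item, `COUNT(order_id)` counts items rather than distinct orders, so the sketch also reports a `distinct_orders` column for comparison.

```python
import pandas as pd


def rfm_features(order_items: pd.DataFrame, as_of: pd.Timestamp) -> pd.DataFrame:
    """Per-user recency, frequency, and monetary (RFM) features.

    Assumes one row per line item with columns user_id, order_id,
    sale_price, and order_created_date (tz-naive timestamps), and uses
    `as_of` in place of CURRENT_DATE().
    """
    stats = order_items.groupby("user_id").agg(
        last_order=("order_created_date", "max"),
        # FREQUENCY as in COUNT(order_id): counts line items, not orders
        count_orders=("order_id", "count"),
        # Distinct orders, for comparison only
        distinct_orders=("order_id", "nunique"),
        # MONETARY as in AVG(sale_price): average price per line item
        average_spend=("sale_price", "mean"),
    )
    # RECENCY: whole calendar days between the last order's date and as_of
    stats["days_since_last_order"] = (
        as_of.normalize() - stats["last_order"].dt.normalize()
    ).dt.days
    return stats.drop(columns="last_order").reset_index()
```

Calling `rfm_features(items, pd.Timestamp("2023-01-01"))` returns one row per `user_id`, which is the shape of table that `CREATE MODEL` is trained on after `user_id` is excluded.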
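The `CREATE MODEL` options used in this notebook (`KMEANS_INIT_METHOD = "KMEANS++"`, `STANDARDIZE_FEATURES = TRUE`) and the `ML.PREDICT` step can be illustrated with a small NumPy sketch of textbook k-means: z-score the features, seed the centers with k-means++, alternate nearest-center assignment with centroid updates, then label each row with its nearest center. This is a generic version written for illustration, not BigQuery ML's implementation; treating `STANDARDIZE_FEATURES` as a z-score is an assumption, and all function names here are hypothetical.

```python
import numpy as np


def standardize(x):
    """Z-score each feature column; constant columns are centered but not scaled."""
    mean = x.mean(axis=0)
    std = x.std(axis=0)
    std = np.where(std == 0, 1.0, std)
    return (x - mean) / std, mean, std


def squared_distances(x, centers):
    """Squared Euclidean distances, shape (n_points, n_centers)."""
    return ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)


def kmeans_pp_init(x, k, rng):
    """k-means++ seeding: pick each new center with probability proportional to D(x)^2."""
    centers = [x[rng.integers(len(x))]]
    for _ in range(1, k):
        d2 = squared_distances(x, np.array(centers)).min(axis=1)
        # Fall back to a uniform choice if every point coincides with a center.
        p = d2 / d2.sum() if d2.sum() > 0 else None
        centers.append(x[rng.choice(len(x), p=p)])
    return np.array(centers)


def kmeans(x, k, n_iter=100, seed=0):
    """Lloyd's algorithm: alternate nearest-center assignment and centroid updates."""
    rng = np.random.default_rng(seed)
    centers = kmeans_pp_init(x, k, rng)
    for _ in range(n_iter):
        labels = squared_distances(x, centers).argmin(axis=1)
        # Move each center to the mean of its points; keep it in place if its cluster is empty.
        new_centers = np.array([
            x[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    # Final assignment to the nearest center (0-based here; CENTROID_ID in the notebook starts at 1).
    return centers, squared_distances(x, centers).argmin(axis=1)
```

Standardizing first matters here because `days_since_last_order` is in the hundreds while `count_orders` is near 1; without scaling, distances would be dominated by recency. To read the centers back in original units, undo the scaling with `centers * std + mean`.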
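`ML.EVALUATE` reports a Davies-Bouldin index and a mean squared distance for the k-means model. The sketch below implements the textbook definitions in NumPy, assuming Euclidean distance, at least two clusters, and no empty clusters: each cluster's scatter is the average distance of its points to its centroid, and the index averages, over clusters, the worst ratio of combined scatter to centroid separation. BigQuery's exact computation may differ in details, and the function names are hypothetical.

```python
import numpy as np


def davies_bouldin(x, labels, centers):
    """Davies-Bouldin index (lower is better), using Euclidean distances."""
    k = len(centers)
    # s[i]: average distance from cluster i's points to its centroid
    s = np.array([np.linalg.norm(x[labels == i] - centers[i], axis=1).mean() for i in range(k)])
    # d[i, j]: distance between centroids i and j
    d = np.linalg.norm(centers[:, None, :] - centers[None, :, :], axis=2)
    ratios = (s[:, None] + s[None, :]) / np.where(d == 0, np.inf, d)
    np.fill_diagonal(ratios, -np.inf)  # exclude i == j from the max
    return float(ratios.max(axis=1).mean())


def mean_squared_distance(x, labels, centers):
    """Mean squared Euclidean distance from each point to its assigned centroid."""
    return float(((x - centers[labels]) ** 2).sum(axis=1).mean())
```

For example, two tight pairs of points centered at 1 and 11 on a line have scatter 1 each and centroid separation 10, so the index is (1 + 1) / 10 = 0.2.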
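The `PIVOT` in the `ML.CENTROIDS` query reshapes one row per (centroid, feature) pair into one row per centroid, and the `cluster_info` loop then turns each row into a line of prompt text. The same two steps can be sketched in pandas, assuming a DataFrame with the `centroid_id`, `feature`, and `numerical_value` columns used in that query; `centroids_to_prompt_lines` is a hypothetical helper name.

```python
import pandas as pd


def centroids_to_prompt_lines(centroids_long: pd.DataFrame) -> list[str]:
    """Pivot long-format centroid features to one row per cluster, then format prompt lines."""
    wide = (
        centroids_long
        # Same rounding as ROUND(numerical_value, 2) in the SQL
        .assign(value=centroids_long["numerical_value"].round(2))
        .pivot(index="centroid_id", columns="feature", values="value")
        .sort_index()
    )
    return [
        f"cluster {cid}, average spend ${row['average_spend']}, "
        f"count of orders per person {row['count_orders']}, "
        f"days since last order {row['days_since_last_order']}"
        for cid, row in wide.iterrows()
    ]
```

Joining the result with `"\n".join(...)` gives the same kind of `clusters` text block that the prompt cells interpolate.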