Skip to content

Instantly share code, notes, and snippets.

@thuwarakeshm
Last active June 16, 2022 00:21
Show Gist options
  • Save thuwarakeshm/5b98e93797fbf993cf47dba0bb987c74 to your computer and use it in GitHub Desktop.
Save thuwarakeshm/5b98e93797fbf993cf47dba0bb987c74 to your computer and use it in GitHub Desktop.
---
title: Create K-Means clusters
description: Specify a title and the number of clusters to get a plot of the clustered dataset
show-code: False
format:
    theme: moon
params:
    title:
        input: text
        label: Set a title for your project
    n_clusters:
        input: slider
        label: How many clusters to have?
        value: 3
        # K-Means needs at least one cluster, so the slider starts at 1.
        min: 1
        max: 10
---
# Default values for the notebook parameters. The names match the `params`
# entries declared in the YAML header (presumably the widget values replace
# these defaults at run time -- confirm against the Mercury documentation).
# `n_clusters` is passed to run_kmeans(); `title` is used as the chart title.
n_clusters=3
title="Chart title"
import pandas as pd
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
# Download the sample voter CSV from GitHub and parse it into a DataFrame.
df = pd.read_csv(
    filepath_or_buffer="https://raw.githubusercontent.com/ThuwarakeshM/PracticalML-KMeans-Election/master/voters_demo_sample.csv",
)
def run_kmeans(df, n_clusters=2, title=None):
    """Cluster rows by age and income with K-Means and plot the result.

    Fits ``KMeans`` on the ``Age`` and ``Income`` columns of *df*, draws a
    scatter plot coloured by cluster label on a new 16x9 figure, and marks
    every centroid with a labelled dot. The figure is left as the current
    matplotlib figure; the function returns ``None``.

    Args:
        df: DataFrame containing ``Age`` and ``Income`` columns.
        n_clusters: Number of clusters to fit; must be at least 1.
        title: Chart title. When ``None``, the module-level ``title``
            variable is used if it exists, otherwise the title is empty.

    Raises:
        ValueError: If ``n_clusters`` is smaller than 1.
    """
    # Fail early with a readable message instead of an error from deep
    # inside the clustering code (a slider value of 0 would get here).
    if n_clusters < 1:
        raise ValueError(f"n_clusters must be at least 1, got {n_clusters!r}")

    if title is None:
        # Backward compatibility: earlier versions read the chart title
        # from a module-level variable instead of taking it as an argument.
        title = globals().get("title", "")

    axis_colour = "#4a4a4a"
    centroid_colour = "#a8323e"

    # n_init is set explicitly because the library default changed between
    # scikit-learn releases; pinning it keeps results stable across versions.
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit(
        df[["Age", "Income"]]
    )

    fig, ax = plt.subplots(figsize=(16, 9))

    # Minimal look: white background, no grid, only left/bottom spines.
    ax.grid(False)
    ax.set_facecolor("#FFF")
    ax.spines[["left", "bottom"]].set_visible(True)
    ax.spines[["left", "bottom"]].set_color(axis_colour)
    ax.tick_params(labelcolor=axis_colour)
    ax.yaxis.label.set(color=axis_colour, fontsize=20)
    ax.xaxis.label.set(color=axis_colour, fontsize=20)

    # Scatter plot of the data points, coloured by assigned cluster.
    ax = sns.scatterplot(
        ax=ax,
        x=df.Age,
        y=df.Income,
        hue=kmeans.labels_,
        palette=sns.color_palette("colorblind", n_colors=n_clusters),
        legend=None,
    )

    ax.set_title(title)

    # Mark each centroid and label it with a 1-based cluster number.
    for number, (age, income) in enumerate(kmeans.cluster_centers_, start=1):
        ax.scatter(age, income, s=200, c=centroid_colour)
        ax.annotate(
            f"Cluster #{number}",
            (age, income),
            fontsize=25,
            color=centroid_colour,
            # Offset the label from the dot (in data coordinates).
            xytext=(age + 5, income + 3),
            bbox=dict(
                boxstyle="round,pad=0.3", fc="white", ec=centroid_colour, lw=2
            ),
            ha="center",
            va="center",
        )
# Fit and draw the clusters using the selected number of clusters.
run_kmeans(df=df, n_clusters=n_clusters)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment