Skip to content

Instantly share code, notes, and snippets.

@thuwarakeshm
Last active July 3, 2021 06:33
Show Gist options
  • Save thuwarakeshm/e520f9e0e29a5c2ba65e8503933d86e8 to your computer and use it in GitHub Desktop.
Save thuwarakeshm/e520f9e0e29a5c2ba65e8503933d86e8 to your computer and use it in GitHub Desktop.
K means clustering for election campaigns
# Fit a 3-cluster K-Means model on the Age, Income and Debt columns
kmeans = KMeans(n_clusters=3, random_state=0)
kmeans.fit(df[["Age", "Income", "Debt"]])
# Collect the inertia of a K-Means model for each candidate number of clusters
# so the values can be compared side by side.
errors = []  # one {"inertia", "num_clusters"} record per fitted model
# The same three columns are used for every fit, so select them once
features = df[["Age", "Income", "Debt"]]
# Loop through some desirable number of clusters (k = 2 .. 9)
for k in range(2, 10):
    # fit() returns the fitted model; its inertia_ attribute holds the
    # value we want to record for this k
    errors.append(
        {
            "inertia": KMeans(n_clusters=k, random_state=0)
            .fit(features)
            .inertia_,
            "num_clusters": k,
        }
    )
# For convenience convert the list to a pandas dataframe
df_inertia = pd.DataFrame(errors)
# Draw inertia (y-axis) against the number of clusters (x-axis) as a line plot
sns.lineplot(x=df_inertia["num_clusters"], y=df_inertia["inertia"])
# Imports
import pandas as pd # Pandas for reading data
from sklearn.cluster import KMeans # KMeans Clustering itself
# Load the voter sample from the CSV file into a pandas dataframe
df = pd.read_csv("./voters_demo_sample.csv")
# Fit a 2-cluster K-Means model using only the voters' age and income
kmeans = KMeans(n_clusters=2, random_state=0)
kmeans.fit(df[["Age", "Income"]])
# Cluster label of each data point, i.e. the group label of each voter.
# NOTE: a bare expression like this is only displayed when run interactively
# (REPL / notebook cell); in a plain script wrap it in print() to see it.
kmeans.labels_
# The (final) cluster centroids found by the fitted model
kmeans.cluster_centers_
# Output
# array([[56.05275779, 39.92805755],
#        [30.55364807, 22.2360515 ]])
# Fit K-Means again on the same two variables, this time asking for 3 clusters
kmeans = KMeans(n_clusters=3, random_state=0)
kmeans.fit(df[["Age", "Income"]])
# Scatter the voters by age and income, coloured by their cluster label
ax = sns.scatterplot(
    x=df["Age"],
    y=df["Income"],
    hue=kmeans.labels_,
    palette=sns.color_palette("colorblind", n_colors=3),
    legend=None,
)
# Imports for visualization
import seaborn as sns
# Scatter the voters by age and income, coloured by their cluster label
# (two-colour palette, legend suppressed)
ax = sns.scatterplot(
    x=df["Age"],
    y=df["Income"],
    hue=kmeans.labels_,
    palette=sns.color_palette("colorblind", n_colors=2),
    legend=None,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment