Skip to content

Instantly share code, notes, and snippets.

@manishmshiva
Created January 3, 2024 13:13
Show Gist options
  • Save manishmshiva/7eac76db7b1e2a48892726cf5dea788b to your computer and use it in GitHub Desktop.
Save manishmshiva/7eac76db7b1e2a48892726cf5dea788b to your computer and use it in GitHub Desktop.
K-means clustering example
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import numpy as np
# Step 1: Build a seeded, two-group toy dataset
np.random.seed(0)  # seed the global NumPy RNG so every run draws the same points
x = np.random.rand(100, 2) * -2  # 100 points, uniform over the square (-2, 0] x (-2, 0]
x1 = np.random.rand(50, 2) * 2 + 1  # 50 points, uniform over the square [1, 3) x [1, 3)
x[50:] = x1  # overwrite the last 50 rows, leaving 50 points in each square
# Step 2: Plot the raw points before any clustering
plt.scatter(*x.T, s=50, c='b')  # x.T unpacks into the first and second coordinate columns
plt.xlabel('X Axis')
plt.ylabel('Y Axis')
plt.show()
# Step 3: Apply KMeans clustering
# n_init is pinned because its default is version-dependent: scikit-learn
# 1.2-1.3 use 10 (and emit a FutureWarning), 1.4+ use 'auto'.
# random_state is passed explicitly so the result does not depend on how many
# values were drawn from the global NumPy RNG earlier in the script.
kmeans = KMeans(n_clusters=2, n_init=10, random_state=0)
kmeans.fit(x)  # Fit the model to the data
# Step 4: Get the coordinates of the cluster centers and the cluster labels
centroids = kmeans.cluster_centers_  # One row per cluster center
labels = kmeans.labels_  # Cluster index assigned to each row of x
# Step 5: Plot the points colored by assigned cluster, with the centers on top
plt.scatter(*x.T, s=50, c=labels, cmap='viridis')  # color each point by its cluster label
plt.scatter(*centroids.T, s=200, c='red', marker='*')  # mark each cluster center with a red star
plt.xlabel('X Axis')
plt.ylabel('Y Axis')
plt.title('K-Means Clustering')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment