Created January 3, 2024 13:13
Save manishmshiva/7eac76db7b1e2a48892726cf5dea788b to your computer and use it in GitHub Desktop.
K-means clustering example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# K-means clustering example: generate two groups of random 2-D points,
# plot them, cluster them into two groups with scikit-learn's KMeans, and
# plot the clustered points together with the cluster centres.
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans

# Step 1: Generate random data (seeded so every run produces the same points)
np.random.seed(0)
# 100 points uniform in [-2, 0] x [-2, 0], i.e. centred on (-1, -1)
x = -2 * np.random.rand(100, 2)
# 50 points uniform in [1, 3] x [1, 3], i.e. centred on (2, 2)
x1 = 1 + 2 * np.random.rand(50, 2)
# Overwrite the second half of x, leaving 50 points in each group
x[50:100, :] = x1

# Step 2: Visualize the data (unclustered)
plt.scatter(x[:, 0], x[:, 1], s=50, c='b')
plt.xlabel('X Axis')
plt.ylabel('Y Axis')
plt.show()

# Step 3: Apply KMeans clustering.
# n_init is pinned because its default changed between scikit-learn releases
# (10 -> 'auto'), which otherwise makes results version-dependent and triggers
# a FutureWarning on some versions; random_state makes the centroid
# initialisation reproducible independently of the global NumPy seed.
kmeans = KMeans(n_clusters=2, n_init=10, random_state=0)
kmeans.fit(x)  # Fit the model to the data

# Step 4: Get the coordinates of the cluster centers and the cluster labels
centroids = kmeans.cluster_centers_  # One (x, y) row per cluster
labels = kmeans.labels_  # Cluster index assigned to each point in x

# Step 5: Visualize the clustered data
# Colour each point by its cluster label
plt.scatter(x[:, 0], x[:, 1], s=50, c=labels, cmap='viridis')
# Mark each centroid with a large red star
plt.scatter(centroids[:, 0], centroids[:, 1], s=200, c='red', marker='*')
plt.xlabel('X Axis')
plt.ylabel('Y Axis')
plt.title('K-Means Clustering')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.