Created
December 27, 2023 05:14
-
-
Save smzn/b81a6c1bd11fe6f0726020f14dd3bf62 to your computer and use it in GitHub Desktop.
Clustering
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.cluster import KMeans

# ---------------------------------------------------------------------------
# K-means clustering
# ---------------------------------------------------------------------------
# How many groups the rows of `combined_top_items_daily` are split into.
n_clusters = 4

# A fixed integer seed is passed as random_state for the estimator.
kmeans = KMeans(random_state=0, n_clusters=n_clusters)

# Fit the model and keep the cluster index assigned to every row.
clusters = kmeans.fit_predict(combined_top_items_daily)

# ---------------------------------------------------------------------------
# PCA scores table
# ---------------------------------------------------------------------------
# Columns are named PC1, PC2, ... and a 'Cluster' column holds the k-means
# assignment.
# NOTE(review): assumes `principal_components_combined` has one row per row
# of `combined_top_items_daily` and `n_components_combined` columns -- confirm.
pca_df = pd.DataFrame(
    principal_components_combined,
    columns=[f'PC{k}' for k in range(1, n_components_combined + 1)],
).assign(Cluster=clusters)

# ---------------------------------------------------------------------------
# Explained variance
# ---------------------------------------------------------------------------
# Per-component ratio as reported by the fitted PCA object, and its running
# total.
explained_variance = pca_combined.explained_variance_ratio_
cumulative_explained_variance = np.cumsum(explained_variance)
def biplot_with_shapes_and_legend(score, coeff, clusters, labels=None):
    """Draw a PCA biplot on the current matplotlib axes, one marker per cluster.

    The first two columns of ``score`` are divided by their own range and
    drawn as a scatter plot, one ``plt.scatter`` call per distinct cluster
    label. Every row of ``coeff`` is then drawn (unscaled) as a red arrow
    starting at the origin, optionally captioned with the matching entry of
    ``labels``. A legend is added at the end.

    Uses the module-level ``plt`` and ``np`` names.

    Args:
        score: 2-D array-like; columns 0 and 1 are the x and y coordinates of
            the observations.
        coeff: 2-D array-like with one row per arrow; columns 0 and 1 are the
            arrow's x and y components.
        clusters: 1-D array-like of integer cluster labels, one per row of
            ``score``. The distinct values found here decide which groups are
            drawn.
        labels: Optional sequence of captions for the arrows. Arrows whose
            index is beyond ``len(labels)`` are drawn without a caption.
    """
    # Accept lists / DataFrames as well as arrays: the boolean masks and the
    # [:, k] indexing below need real ndarrays.
    score = np.asarray(score)
    coeff = np.asarray(coeff)
    clusters = np.asarray(clusters)

    xs = score[:, 0]
    ys = score[:, 1]

    # Scale each axis by its range. A constant column has zero range, in
    # which case the values are left unscaled instead of dividing by zero.
    x_span = xs.max() - xs.min()
    y_span = ys.max() - ys.min()
    scalex = 1.0 / x_span if x_span != 0 else 1.0
    scaley = 1.0 / y_span if y_span != 0 else 1.0

    # Marker shapes are reused cyclically, so more than four clusters no
    # longer raises IndexError.
    shapes = ['o', 's', '^', 'D']

    # Iterate over the labels actually present in `clusters` rather than
    # relying on a global cluster count.
    for cluster in np.unique(clusters):
        cluster_id = int(cluster)
        members = clusters == cluster
        plt.scatter(
            xs[members] * scalex,
            ys[members] * scaley,
            # `color=` (not `c=`): a single RGB triple passed as `c` is
            # ambiguous and is treated as colormap data when the cluster
            # happens to contain exactly three points.
            color=tuple(np.random.rand(3)),
            marker=shapes[cluster_id % len(shapes)],
            label=f'Cluster {cluster_id + 1}',
        )

    # One arrow per row of `coeff`, with an optional caption placed slightly
    # beyond the arrow tip.
    for i, (dx, dy) in enumerate(coeff[:, :2]):
        plt.arrow(0, 0, dx, dy, color='r', alpha=0.5)
        if labels is not None and i < len(labels):
            plt.text(dx * 1.15, dy * 1.15, labels[i],
                     color='g', ha='center', va='center')

    # Adding a legend
    plt.legend()
# Render the cluster-annotated biplot on a fresh 12x8 inch figure.
plt.figure(figsize=(12, 8))
biplot_with_shapes_and_legend(
    principal_components_combined,
    # First two rows of the fitted components, transposed so that each row
    # carries one variable's (PC1, PC2) coefficients.
    pca_combined.components_[:2].T,
    clusters,
    labels=daily_top_items_sales_with_total.columns,
)
plt.title('Biplot of PCA (Including Total Daily Sales) with Clusters')
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment