Last active: August 22, 2019 12:59
-
-
Save xoraus/1dfe2807568357726d29666730299f18 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fit mini-batch k-means to X, then draw every cluster centroid as a grayscale
# image titled with the label inferred for that cluster.
#
# NOTE(review): assumes X is (n_samples, 784) with pixel values scaled to
# [0, 1], that Y holds the matching true labels, and that
# infer_cluster_labels(kmeans, Y) returns {label: iterable of cluster indices}
# -- confirm against the code that defines them.

N_CLUSTERS = 36         # must equal GRID_SIDE ** 2 so every subplot gets a centroid
GRID_SIDE = 6           # subplot grid is GRID_SIDE x GRID_SIDE
IMAGE_SHAPE = (28, 28)  # (height, width) of one flattened sample in X

# Initialize and fit KMeans algorithm
kmeans = MiniBatchKMeans(n_clusters=N_CLUSTERS)
kmeans.fit(X)

# record centroid values
centroids = kmeans.cluster_centers_

# Reshape centroids into images and rescale to 0-255. Multiply out of place:
# reshape() may return a view, so an in-place `*= 255` would also rescale
# kmeans.cluster_centers_ and corrupt the fitted model for later predict().
# Round and clip before the uint8 cast so float error cannot truncate or wrap.
scaled = centroids.reshape(N_CLUSTERS, *IMAGE_SHAPE) * 255
images = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

# determine cluster labels
cluster_labels = infer_cluster_labels(kmeans, Y)

# Invert {label: clusters} into {cluster: label} once instead of scanning every
# label for every subplot. If a cluster is listed under several labels, the
# last one wins, matching the title the repeated set_title calls used to leave.
cluster_to_label = {}
for label, clusters in cluster_labels.items():
    for cluster in clusters:
        cluster_to_label[cluster] = label

# create figure with subplots using matplotlib.pyplot
fig, axs = plt.subplots(GRID_SIDE, GRID_SIDE, figsize=(20, 20))
plt.gray()

# loop through subplots and add centroid images
for i, ax in enumerate(axs.flat):
    if i in cluster_to_label:
        ax.set_title('Inferred Label: {}'.format(cluster_to_label[i]))
    ax.matshow(images[i])
    ax.axis('off')

# Use plt.show() rather than fig.show(). Figure.show() does not run a GUI event
# loop, so from a plain script the window may flash or never appear.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.