Skip to content

Instantly share code, notes, and snippets.

@koshian2
Created November 13, 2018 11:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save koshian2/9075d991c291caf1b2a45b0c3c416c7e to your computer and use it in GitHub Desktop.
Save koshian2/9075d991c291caf1b2a45b0c3c416c7e to your computer and use it in GitHub Desktop.
Latent features in U-Net
from sklearn.cluster import KMeans
import numpy as np
from keras.datasets import mnist
import os, tarfile, datetime
def _purity(cluster_labels, true_labels, n_clusters=10, n_classes=10):
    """Return the purity of a clustering measured against ground-truth labels.

    Purity is the fraction of samples that carry the most frequent
    ground-truth label of the cluster they were assigned to.
    """
    # Contingency table -- row: k-means cluster, col: ground-truth class.
    matrix = np.zeros((n_clusters, n_classes), np.int64)
    # np.add.at accumulates repeated (row, col) pairs; a plain fancy-indexed
    # "+=" would count each distinct pair only once.
    np.add.at(matrix, (cluster_labels, true_labels), 1)
    # sum(row_max / row_sum * row_sum / total) simplifies to the expression
    # below, which avoids a 0/0 (NaN) for a cluster that received no samples.
    return np.sum(np.max(matrix, axis=-1)) / np.sum(matrix)


def kmeans(path, state):
    """Cluster a feature set into 10 groups with k-means and score its purity.

    Args:
        path: ``"original"`` to use the MNIST training images returned by
            ``mnist.load_data()`` (scaled by 1/255 and flattened to 784
            values per image); otherwise the path of an ``.npz`` archive
            containing the arrays ``"latent"`` and ``"ground_truth"``.
        state: Value passed to ``KMeans`` as ``random_state``.

    Returns:
        The total purity of the clustering with respect to the ground-truth
        labels. Every sample in the data set is counted, however many there
        are.
    """
    if path == "original":
        (latent, gt), _ = mnist.load_data()
        latent = (latent / 255.0).reshape(-1, 784)
    else:
        # Close the archive as soon as both arrays have been read.
        with np.load(path) as data:
            latent = data["latent"]
            gt = data["ground_truth"]
    gt = np.ravel(gt)
    km = KMeans(n_clusters=10, random_state=state)
    km_label = km.fit_predict(latent)
    return _purity(km_label, gt)
def kmeans_multiple(path):
    """Run ``kmeans`` on *path* with random states 0-9 and save the purities.

    Progress, the individual purities and their mean are printed. The ten
    purities are stored under the key ``results`` in an ``.npz`` file whose
    name is ``"results/"`` followed by *path* with ``".npz"`` removed
    (``np.savez`` appends the ``.npz`` extension itself).

    Args:
        path: Feature-set identifier forwarded unchanged to ``kmeans``.
    """
    n_runs = 10
    results = np.zeros(n_runs)
    print("Starts...", path)
    for seed in range(n_runs):
        print("... i = ", seed, datetime.datetime.now())
        results[seed] = kmeans(path, seed)
    print(path)
    print(results)
    print(np.mean(results))
    out_file = "results/" + path.replace(".npz", "")
    # exist_ok avoids the check-then-create race, and creating the parent of
    # the output file also covers a *path* that contains sub-directories.
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    np.savez(out_file, results=results)
def kmeans_all():
    """Evaluate every feature set, then pack the ``results`` directory.

    ``kmeans_multiple`` is called once for ``"original"`` and once for each
    of the four latent-feature archives; afterwards the ``results`` directory
    is added to the gzip-compressed tar file ``results.tar.gz``.
    """
    feature_sets = (
        "original",
        "latent_skip_False_bottleneck_False.npz",
        "latent_skip_False_bottleneck_True.npz",
        "latent_skip_True_bottleneck_False.npz",
        "latent_skip_True_bottleneck_True.npz",
    )
    for feature_set in feature_sets:
        kmeans_multiple(feature_set)

    with tarfile.open("results.tar.gz", mode="w:gz") as archive:
        archive.add("results")
# Run the full experiment only when executed as a script, so that importing
# this module does not start the clustering runs as a side effect.
if __name__ == "__main__":
    kmeans_all()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment