# Imports
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import tensorflow as tf
import pandas as pd
import numpy as np
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

plt.style.use('fivethirtyeight')
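# Note: this gist targets TensorFlow 1.x (tf.placeholder, tf.Session,
# tf.variable_scope). Under TensorFlow 2.x the 1.x API would presumably be
# reached through the compat shim below; an assumption, not part of the
# original gist:
#   import tensorflow.compat.v1 as tf
#   tf.disable_v2_behavior()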
# Define Constants
N_SAMPLES = int(1e4)
N_CLUSTERS = 3
N_DIMENSIONS = 4
BATCH_SIZE = 512

# Create Mock Dataset
X, y = make_blobs(n_samples=N_SAMPLES, n_features=N_DIMENSIONS, centers=N_CLUSTERS)

# Calculate Initial Centers by Sampling from the Dataset
# (sampling without replacement so no two clusters start from the same point)
inds = np.random.choice(N_SAMPLES, size=N_CLUSTERS, replace=False)
centers = X[inds, :]
# Define Graph
graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope('ClusteringModel'):

        # Define the Input Layer
        with tf.variable_scope('cluster_input'):
            next_batch = tf.placeholder(
                dtype=tf.float32,
                shape=[None, N_DIMENSIONS],
                name='input')

        # Define Centroids [row vectors]
        with tf.variable_scope('centroids', reuse=tf.AUTO_REUSE):
            centroids = []
            for i in range(N_CLUSTERS):
                centroids.append(
                    tf.get_variable(
                        name='mean_%d' % i,
                        shape=[N_DIMENSIONS],
                        dtype=tf.float32,
                        initializer=tf.initializers.constant(value=centers[i])))

        # Define the Covariance Matrix for the Mahalanobis Distance
        with tf.variable_scope('covariances', reuse=tf.AUTO_REUSE):
            covariances = []
            for i in range(N_CLUSTERS):
                covariances.append(
                    tf.constant(
                        value=np.eye(N=N_DIMENSIONS),
                        name='covariance_%d' % i,
                        shape=[N_DIMENSIONS, N_DIMENSIONS],
                        dtype=tf.float32))
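        # Note: these are identity constants, so the quadratic form below
        # reduces to a squared Euclidean distance. A true Mahalanobis distance
        # would plug the inverse covariance (precision) matrix in here, e.g.
        # via a learnable tf.get_variable instead of tf.constant.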
        # Zip the Centroid Vectors and Covariance Matrices into Clusters
        # (list() so the pairing is reusable; a bare zip is a one-shot iterator in Python 3)
        clusters = list(zip(centroids, covariances))

        # Normalize the Batch Vectors
        with tf.variable_scope('normalization'):
            batch_mean, batch_variance = tf.nn.moments(
                next_batch,
                axes=[0],
                name='batch_moments')
            normalized_batch = tf.nn.batch_normalization(
                x=next_batch,
                mean=batch_mean,
                variance=batch_variance,
                offset=0.0,
                scale=0.99,
                variance_epsilon=1e-6)
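        # Note: offset=0.0 and scale=0.99 apply one fixed affine transform
        # after standardization, rather than the learned gamma/beta usually
        # passed to batch normalization.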
        # Define the Operations for Calculating Cluster Distances
        with tf.variable_scope('distances'):
            distances = []
            for i, (centroid, covariance) in enumerate(clusters):

                # Normalize the Centroid
                centroid_norm = tf.cast(
                    tf.linalg.norm(
                        centroid,
                        axis=-1,
                        keepdims=True,
                        name='uncasted_centroid_%d_norm' % i),
                    name='centroid_%d_norm' % i,
                    dtype=tf.float32)
                normalized_centroid = tf.divide(
                    centroid, centroid_norm, name='normalized_centroid_%d' % i)

                # Calculate the distance between the batch and the normalized centroid
                centroid_diff = tf.subtract(
                    normalized_batch, normalized_centroid, name='centr_diff_%d' % i)
                dist = tf.linalg.tensor_diag_part(
                    tf.tensordot(
                        a=tf.tensordot(
                            a=centroid_diff,
                            b=covariance,
                            axes=[-1, 0],
                            name='Mahalanobis_dist_%d_part1' % i),
                        b=tf.transpose(centroid_diff),
                        axes=[-1, 0],
                        name='Mahalanobis_dist_%d_part2' % i),
                    name='distances_%d' % i)
                distances.append(dist)
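            # An equivalent but cheaper formulation (a sketch, untested here)
            # avoids materializing the batch_size x batch_size product whose
            # diagonal is read off above:
            #   proj = tf.tensordot(centroid_diff, covariance, axes=[-1, 0])
            #   dist = tf.reduce_sum(proj * centroid_diff, axis=-1)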
        # Stack the Distances and Find the Argmin
        with tf.variable_scope('cluster_assignments'):
            distance_vectors = tf.stack(values=distances, axis=-1, name='stack_distances')
            cluster_assignments = tf.argmin(distance_vectors, axis=-1, name='cluster_assignments')
        # Update Centroids by Gathering the Data Points Assigned to Each Cluster
        with tf.variable_scope('centroid_updates'):

            # Keep a running list of centroid update operations
            centroid_updates = []
            for c in range(N_CLUSTERS):

                # Find the indices of the data points assigned to cluster c
                index_mask = tf.squeeze(
                    tf.where(
                        tf.equal(cluster_assignments, c),
                        name='where_equal_to_cluster_%d' % c),
                    name='cluster_%d_indices' % c)

                # Gather the points assigned to cluster c
                cluster_points = tf.gather(
                    params=next_batch,
                    indices=index_mask,
                    name='gather_cluster_points_%d' % c)

                # Reshape so the gathered points keep rank 2 even when the
                # squeeze above leaves zero or one matching rows
                cluster_points = tf.reshape(
                    cluster_points, shape=[-1, N_DIMENSIONS], name='cluster_points_%d' % c)

                # Find the mean and variance of the cluster data
                cluster_mean, cluster_variance = tf.nn.moments(
                    cluster_points,
                    axes=[0],
                    name='cluster_moments')

                # Add the update op to the update-op list
                centroid_update_step = centroids[c].assign(cluster_mean)
                centroid_updates.append(centroid_update_step)
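            # Caveat (not handled in this gist): if a cluster receives no
            # points in a batch, tf.nn.moments returns NaN and the assign
            # would corrupt that centroid. A hedged guard, sketched only:
            #   safe_mean = tf.cond(tf.size(cluster_points) > 0,
            #                       lambda: cluster_mean,
            #                       lambda: centroids[c].read_value())
            #   centroid_update_step = centroids[c].assign(safe_mean)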
# Run Session
with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    means = []

    # Sample random batches and run the centroid updates
    for _ in range(1000):
        random_indices = np.random.randint(0, X.shape[0], size=BATCH_SIZE)
        x_ = np.take(X, indices=random_indices, axis=0)
        means.append(sess.run(centroid_updates, feed_dict={next_batch: x_}))

    # Get Full Cluster Assignments
    assignments = sess.run(cluster_assignments, feed_dict={next_batch: X})

# Transform and Transpose the Means Array:
# (steps, clusters, dims) -> (dims, clusters, steps)
means = np.array(means)
means = np.transpose(means, axes=[2, 1, 0])
# Plot Parameter Trace
figure, axes = plt.subplots(2, 2, figsize=(14, 8), sharey=True, sharex=True)
colors = ['b', 'g', 'r']
k = 0
i = 0
j = 0
while k < N_DIMENSIONS:
    mean = means[k, :, :]
    for cluster in range(N_CLUSTERS):
        axes[i, j].plot(range(mean.shape[-1]), mean[cluster, :], c=colors[cluster])
    axes[i, j].set_title('Parameter for Feature %d' % k)
    if j == 1:
        j = 0
        i += 1
    else:
        j += 1
    k += 1
figure.suptitle('Parameter Trace', fontsize=32)
# Plot Confusion Matrix
from sklearn.metrics import confusion_matrix
import seaborn as sns

figure = plt.figure(figsize=(8, 8))
data = confusion_matrix(y, assignments)
sns.heatmap(data, annot=True, cmap=plt.cm.get_cmap('summer'), cbar=False)
plt.xlabel('Predicted Value')
plt.ylabel('Actual Value')
plt.show()
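# The cluster indices produced by argmin are arbitrary, so the heatmap above
# can look permuted even when the clustering is good. A minimal alignment
# sketch (assumes scipy is available; an addition for illustration, not part
# of the original gist):
from scipy.optimize import linear_sum_assignment
_, col_order = linear_sum_assignment(-data)  # permutation maximizing the diagonal
aligned = data[:, col_order]                 # columns reordered to match the true labels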