Skip to content

Instantly share code, notes, and snippets.

@sbarratt
Created November 3, 2017 00:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sbarratt/42da6a307e27c15ae00e783cd13cd36e to your computer and use it in GitHub Desktop.
Save sbarratt/42da6a307e27c15ae00e783cd13cd36e to your computer and use it in GitHub Desktop.
K-means script that works with NaN entries.
"""
Author: Shane Barratt
Email: sbarratt@stanford.edu
K-means script that works with NaN entries.
"""
import numpy as np
import IPython as ipy
import matplotlib.pyplot as plt
def kmeans(data, k, max_iterations):
    """Cluster the rows of ``data`` into ``k`` groups with Lloyd-style k-means.

    The algorithm alternates between assigning every row to its nearest
    centroid and recomputing each centroid from its assigned rows, until
    ``_should_stop`` reports convergence or the iteration budget is spent.
    NaN handling is delegated entirely to the helper functions
    (``_get_labels`` / ``_get_centroids``).

    Args:
        data: 2-D array, one sample per row (``data.shape[1]`` is read as the
            feature count); entries may be NaN.
        k: number of clusters.
        max_iterations: iteration budget passed through to ``_should_stop``.

    Returns:
        Tuple ``(centroids, labels)``: the final centroids and the label of
        each row computed against those final centroids.
    """
    centroids = _get_random_centroids(data.shape[1], k, data)

    previous = None  # None signals "no iteration has run yet" to _should_stop
    n_iter = 0
    while not _should_stop(previous, centroids, n_iter, max_iterations):
        previous = centroids
        n_iter += 1
        # Assignment step, then update step.
        assignment = _get_labels(data, centroids)
        centroids = _get_centroids(data, assignment, k)

    final_labels = _get_labels(data, centroids)
    return centroids, final_labels
def _distance(x, centroids):
return np.sqrt(np.nansum((x - centroids)**2, axis=1))
def _mean(data):
return np.nan_to_num(np.nanmean(data, axis=0))
def _get_random_centroids(num_features, k, data):
centroids = np.random.normal(size=(k, num_features))
return centroids
def _should_stop(old_centroids, centroids, iterations, max_iterations):
# Stop if centroids haven't changed or if a certain number of iterations have passed
if iterations > max_iterations: return True
if old_centroids is None:
return False
return np.all(np.equal(old_centroids, centroids))
def _get_labels(data, centroids):
# For each element in the data, chose the closest centroid.
# Make that centroid the element's label.
N = data.shape[0]
labels = np.zeros(N)
for i in range(N):
x = data[i, :]
dist_to_centroids = _distance(x, centroids)
labels[i] = np.argmin(dist_to_centroids)
return labels
def _get_centroids(data, labels, k):
    """Recompute each centroid from the rows currently assigned to it.

    Centroid ``j`` is the NaN-ignoring arithmetic mean (via ``_mean``) of the
    rows whose label equals ``j``.

    If cluster ``j`` has no members, it is re-seeded from a data row chosen
    uniformly at random from all N rows. Any NaN coordinates in that row are
    replaced by the NaN-ignoring column means of the full data, so the new
    centroid stays NaN-free.

    Args:
        data: array of shape (N, num_features); may contain NaN.
        labels: length-N array of cluster indices (compared with ``== j``).
        k: number of clusters.

    Returns:
        Array of shape (k, num_features).
    """
    centroids = np.zeros((k, data.shape[1]))
    for j in range(k):
        members = data[labels == j, :]
        if members.shape[0] == 0:
            # Draw from all N rows; drawing from range(k) would only ever
            # pick one of the first k rows.
            seed_row = data[np.random.randint(data.shape[0]), :]
            # A NaN centroid coordinate would be skipped by the nansum in the
            # distance (making this centroid look spuriously close), and
            # NaN != NaN would stop the equality-based convergence check from
            # ever firing, so fill missing coordinates with column means.
            centroids[j, :] = np.where(np.isnan(seed_row), _mean(data), seed_row)
        else:
            centroids[j, :] = _mean(members)
    return centroids
if __name__ == '__main__':
    # Demo data: two 2-D Gaussian blobs with identity covariance, 1000 points
    # centred at [1, 1] followed by 1000 points centred at [-1, -1].
    data = np.r_[
        np.random.multivariate_normal(np.ones(2), np.eye(2), size=1000),
        np.random.multivariate_normal(-np.ones(2), np.eye(2), size=1000),
    ]

    # Independently blank out each entry with probability 0.2.
    missing = np.random.random(data.shape) < .2
    data[missing] = np.nan

    # Cluster into 2 groups (iteration budget 200) and plot the centroids in red.
    centroids, labels = kmeans(data, 2, 200)
    plt.scatter(centroids[:, 0], centroids[:, 1], c='r')
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment