@dave-andersen
Last active September 1, 2022 11:15
k-means in TensorFlow
import tensorflow as tf
import numpy as np
import time

N = 10000
K = 4
MAX_ITERS = 1000

start = time.time()

points = tf.Variable(tf.random_uniform([N, 2]))
cluster_assignments = tf.Variable(tf.zeros([N], dtype=tf.int64))

# Silly initialization: Use the first K points as the starting
# centroids. In the real world, do this better.
centroids = tf.Variable(tf.slice(points.initialized_value(), [0, 0], [K, 2]))

# Replicate to N copies of each centroid and K copies of each
# point, then subtract and compute the sum of squared distances.
rep_centroids = tf.reshape(tf.tile(centroids, [N, 1]), [N, K, 2])
rep_points = tf.reshape(tf.tile(points, [1, K]), [N, K, 2])
sum_squares = tf.reduce_sum(tf.square(rep_points - rep_centroids), axis=2)

# Use argmin to select the nearest centroid for each point.
best_centroids = tf.argmin(sum_squares, 1)

did_assignments_change = tf.reduce_any(tf.not_equal(best_centroids,
                                                    cluster_assignments))

def bucket_mean(data, bucket_ids, num_buckets):
    total = tf.unsorted_segment_sum(data, bucket_ids, num_buckets)
    count = tf.unsorted_segment_sum(tf.ones_like(data), bucket_ids, num_buckets)
    return total / count

means = bucket_mean(points, best_centroids, K)

# Do not write to the assigned clusters variable until after
# computing whether the assignments have changed - hence the
# control_dependencies block.
with tf.control_dependencies([did_assignments_change]):
    do_updates = tf.group(
        centroids.assign(means),
        cluster_assignments.assign(best_centroids))

init = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init)

changed = True
iters = 0
while changed and iters < MAX_ITERS:
    iters += 1
    [changed, _] = sess.run([did_assignments_change, do_updates])

[centers, assignments] = sess.run([centroids, cluster_assignments])
end = time.time()

print("Found in %.2f seconds" % (end - start), iters, "iterations")
print("Centroids:")
print(centers)
print("Cluster assignments:", assignments)
@kzhang28

Hi dave-andersen,
Based on your code, if I wanted to build a distributed version of the k-means algorithm in a distributed TensorFlow environment (say 1 ps + n workers), what would be a good way to synchronize across the workers? Do you have any ideas about that?
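
(Sketching one common pattern here, not an answer from the gist's author: because the centroid update is just a per-cluster sum and count, each worker can compute those partial statistics on its own shard of the points, and synchronization reduces to summing the small [K, D] tensors on the parameter server before dividing. A single-process NumPy sketch of that reduction, with the actual ps/worker plumbing left out:)

import numpy as np

K, D = 4, 2
rng = np.random.default_rng(0)
# One shard of points per "worker"; a real setup would place these on n workers.
shards = [rng.random((2500, D)) for _ in range(4)]
centroids = shards[0][:K].copy()

def partial_stats(shard, centroids):
    # Worker-local step: assign each point to its nearest centroid, then
    # return per-cluster sums and counts (small [K, D] and [K] arrays).
    dists = ((shard[:, None, :] - centroids) ** 2).sum(axis=2)
    best = dists.argmin(axis=1)
    sums = np.zeros((K, D))
    counts = np.zeros(K)
    for k in range(K):
        mask = best == k
        sums[k] = shard[mask].sum(axis=0)
        counts[k] = mask.sum()
    return sums, counts

# "Parameter server" step: add up the per-worker statistics, then divide.
# (A real implementation would also guard against empty clusters.)
stats = [partial_stats(s, centroids) for s in shards]
centroids = sum(s for s, _ in stats) / sum(c for _, c in stats)[:, None]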

@rawoke083

How would I add more dimensions (features)?
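
(Not from the thread, but for illustration: the dimensionality is just the hard-coded 2 in the shapes above, so a D-dimensional version only needs those shapes parameterized. A minimal sketch with a hypothetical D, reusing the gist's imports and its N and K:)

D = 5  # hypothetical feature count; the gist hard-codes 2

points = tf.Variable(tf.random_uniform([N, D]))
centroids = tf.Variable(tf.slice(points.initialized_value(), [0, 0], [K, D]))
rep_centroids = tf.reshape(tf.tile(centroids, [N, 1]), [N, K, D])
rep_points = tf.reshape(tf.tile(points, [1, K]), [N, K, D])
# Everything downstream (sum_squares, argmin, bucket_mean) is already
# shape-agnostic, so no other changes are needed.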

@yunxunmi

Can this do text clustering?
How would I do that?

@yindia

yindia commented May 21, 2017

For text clustering, first convert your dataset into vectors using TfidfVectorizer, and then apply any clustering algorithm.

For a deeper treatment, see https://github.com/nfmcclure/tensorflow_cookbook/tree/master/07_Natural_Language_Processing
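
(To make that concrete, a minimal sketch of the pipeline described above, using scikit-learn's TfidfVectorizer and KMeans; the toy documents and the n_clusters value are made up for illustration:)

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans

docs = [
    "the cat sat on the mat",
    "dogs and cats are pets",
    "tensorflow computes on tensors",
    "deep learning with tensorflow",
]

# Turn each document into a sparse TF-IDF vector ...
X = TfidfVectorizer().fit_transform(docs)
# ... then hand those vectors to any clustering algorithm.
labels = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X)
print(labels)  # e.g. one cluster for the pet sentences, one for the TF ones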
