Skip to content

Instantly share code, notes, and snippets.

@bvarghese1
Last active April 15, 2019 04:35
Show Gist options
  • Save bvarghese1/0ab70a15b077945b8f027dc797a79ebd to your computer and use it in GitHub Desktop.
Save bvarghese1/0ab70a15b077945b8f027dc797a79ebd to your computer and use it in GitHub Desktop.
import os

import numpy as np
from sklearn.cluster import KMeans

from cluster import Cluster
class KmeansCluster(Cluster):
    """K-means implementation of the ``Cluster`` interface.

    Fits scikit-learn's ``KMeans`` to a feature matrix, recomputes one centroid
    per cluster as the mean of the samples assigned to it, scales every
    centroid to unit L2 norm and saves the result with ``np.save`` under
    ``self.path``.
    """

    def __init__(self, path):
        # ``path`` is forwarded unchanged; ``cluster`` later reads
        # ``self.path``, presumably set by ``Cluster.__init__`` — confirm.
        super(KmeansCluster, self).__init__(path)

    # Implementation of the base class abstract method
    def cluster(self, features, num_clusters):
        """Cluster ``features`` with k-means and save unit-norm centroids.

        Args:
            features: 2-D array-like, one row per sample and one column per
                feature (the column count sets the centroid width).
            num_clusters: number of clusters passed to ``KMeans`` as
                ``n_clusters``.

        Returns:
            A ``(num_clusters, n_features)`` array of centroids, each scaled
            to unit L2 norm. The same array is written to
            ``<self.path>/kmeans_centroids.npy`` (``np.save`` adds the
            ``.npy`` suffix). Returning it is new; callers that ignored the
            previous ``None`` are unaffected.
        """
        features = np.asarray(features)

        # Fit k-means and take the per-sample cluster assignments.
        clustering_algo = KMeans(n_clusters=num_clusters)
        clustering_algo.fit(features)
        labels = clustering_algo.labels_

        centroids = self._normalize_rows(
            self._centroids_from_labels(features, labels, num_clusters)
        )

        np.save(os.path.join(self.path, "kmeans_centroids"), centroids)
        return centroids

    @staticmethod
    def _centroids_from_labels(features, labels, num_clusters):
        """Return the mean of the rows of ``features`` assigned to each label."""
        # Float result type: float32 input stays float32, while ints and
        # float64 become float64.
        dtype = np.result_type(features.dtype, np.float32)
        centroids = np.zeros((num_clusters, features.shape[1]), dtype=dtype)
        for i in range(num_clusters):
            members = features[labels == i]
            # An empty cluster keeps an all-zero row instead of the NaN row
            # that 0/0 would otherwise produce.
            if len(members):
                centroids[i] = members.mean(axis=0)
        return centroids

    @staticmethod
    def _normalize_rows(matrix):
        """Scale each row of ``matrix`` to unit L2 norm; all-zero rows stay zero."""
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        # ``where`` skips zero-norm rows, so they keep the zeros from ``out``
        # instead of becoming NaN.
        return np.divide(matrix, norms, out=np.zeros_like(matrix), where=norms > 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment