Skip to content

Instantly share code, notes, and snippets.

@bvarghese1
Created April 13, 2019 20:29
Show Gist options
  • Save bvarghese1/bf9dabb64533b5040e1dc9cc47b58535 to your computer and use it in GitHub Desktop.
Save bvarghese1/bf9dabb64533b5040e1dc9cc47b58535 to your computer and use it in GitHub Desktop.
import os
import time
import numpy as np
from abc import abstractmethod
class ClusterTemplate(object):
    """Template-method base class that clusters embeddings and saves centroids.

    Subclasses supply the concrete algorithm by implementing
    ``init_cluster_algo`` (which is expected to assign
    ``self.clustering_algo``) and may override ``train``,
    ``should_normalize`` and ``normalize``. Clients call ``cluster``, which
    runs the steps in a fixed order and writes ``self.centroids`` to disk.
    """

    def __init__(self, path):
        """Remember the output directory and create it if it is missing.

        Args:
            path: Directory into which ``save`` writes the centroid file.
        """
        self.path = path
        if not os.path.exists(self.path):
            os.makedirs(self.path, exist_ok=True)
        # Populated by ``normalize`` (or by a subclass) before ``save`` runs.
        self.centroids = None
        # Populated by the subclass in ``init_cluster_algo``.
        self.clustering_algo = None

    @abstractmethod
    def init_cluster_algo(self, num_clusters):
        """Create the clustering algorithm for ``num_clusters`` clusters.

        Implementations are expected to store an object exposing ``fit``
        (and ``labels_`` when the default ``normalize`` is used) in
        ``self.clustering_algo``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Abstract method 'init_cluster_algo' not implemented")

    def train(self, embeddings):
        """Fit ``self.clustering_algo`` on ``embeddings``."""
        self.clustering_algo.fit(embeddings)

    def should_normalize(self):
        """Return True when ``cluster`` should call ``normalize`` after training."""
        return True

    def normalize(self, embeddings, num_clusters):
        """Compute unit-length mean centroids from the fitted cluster labels.

        For each cluster index in ``range(num_clusters)`` the mean of the
        member rows of ``embeddings`` is taken, then each mean is scaled to
        unit L2 norm. The result is stored in ``self.centroids`` with one row
        per cluster.

        Args:
            embeddings: 2-D array-like; rows are indexed in the same order as
                ``self.clustering_algo.labels_``.
            num_clusters: Number of cluster indices to build centroids for.

        Note:
            A cluster index with no members divides by zero and therefore
            produces a NaN row (NumPy emits a RuntimeWarning).
        """
        embeddings = np.asarray(embeddings)
        # Extract the assigned cluster labels
        labels = np.asarray(self.clustering_algo.labels_)

        # Generate centroids using the embeddings and assigned cluster labels.
        # The empty float32 seed keeps the result dtype and the
        # ``num_clusters == 0`` result the same as stacking row by row.
        rows = [np.empty((0, embeddings.shape[1]), 'float32')]
        for i in range(num_clusters):
            members = labels == i
            # Boolean mask dotted with the data sums the member rows.
            rows.append(np.dot(members, embeddings) / np.sum(members))
        data = np.vstack(rows)

        # Normalize every centroid row to unit L2 norm
        norms = np.sqrt(np.sum(data * data, axis=1, keepdims=True))
        self.centroids = data / norms

    def save(self, cluster_name):
        """Write ``self.centroids`` to ``<path>/<cluster_name>`` via ``np.save``."""
        np.save(os.path.join(self.path, cluster_name), self.centroids)

    # Template method: subclasses must not override it. Clients invoke it directly.
    def cluster(self, features, cluster_name, niter=20, num_clusters=100):
        """Run the full pipeline: init, train, optionally normalize, save.

        Args:
            features: Embeddings to cluster; passed to ``train`` and
                ``normalize``.
            cluster_name: File name (inside ``self.path``) for the saved
                centroids.
            niter: Accepted for interface compatibility; not used by this
                base implementation.
            num_clusters: Number of clusters, forwarded to
                ``init_cluster_algo`` and ``normalize``.
        """
        self.init_cluster_algo(num_clusters)
        self.train(features)
        if self.should_normalize():
            self.normalize(features, num_clusters)
        self.save(cluster_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment