Last active
August 29, 2015 14:10
-
-
Save frnsys/3ab30df1bc160c6f4eeb to your computer and use it in GitHub Desktop.
Simple hierarchical agglomerative clustering (HAC) implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import operator
from functools import reduce
from itertools import combinations

import numpy as np
def hac(vecs, sim_func, threshold):
    """
    Hierarchical Agglomerative Clustering.

    Greedily merges the most similar pair of clusters until the best
    similarity drops below ``threshold`` or only one cluster remains.

    Args:
        vecs: Sequence of items indexable by ``0 .. len(vecs) - 1``.
        sim_func: Called as ``sim_func(a, b)`` where ``a`` and ``b`` are
            ``np.array``s built from the members of two clusters. Higher
            return values mean "more similar". Scores are reused across
            rounds, so this is assumed to be a pure function of its inputs.
        threshold: A pair is merged only if its similarity is
            ``>= threshold``.

    Returns:
        A list of ``len(vecs)`` integer labels. Items that were never merged
        keep label ``0``, so ``0`` is shared by all singletons and does not
        identify one cluster. Every merge assigns a fresh label to all
        members of the new cluster, starting at ``len(vecs) + 1`` and
        increasing by one per merge. Items in the same final cluster
        therefore share a label.
    """
    def sim(pair):
        # Stack each cluster's members so sim_func compares whole clusters.
        a = np.array([vecs[i] for i in pair[0]])
        b = np.array([vecs[i] for i in pair[1]])
        return sim_func(a, b)

    size = len(vecs)
    labels = [0] * size

    # Each item starts in its own cluster (a tuple of member indices).
    clusters = {(i,) for i in range(size)}

    # First label handed out by a merge; labels 1..size are never used.
    next_label = size + 1

    # Scores from the previous round, keyed by the ordered cluster pair.
    # Pairs that don't involve the freshly merged cluster usually hit, so in
    # the common case only pairs with the new cluster call sim_func again.
    cache = {}

    # With fewer than two clusters there is nothing to compare (and max()
    # of an empty sequence would raise), so stop.
    while len(clusters) > 1:
        scores = []
        for p in combinations(clusters, 2):
            score = cache[p] if p in cache else sim(p)
            scores.append((p, score))
        # Keep only the current round's pairs; merged clusters never return.
        cache = dict(scores)

        # max() keeps the first maximal pair, so ties resolve in set order.
        pair, best = max(scores, key=operator.itemgetter(1))

        # Stop if even the most similar pair is below the threshold.
        if best < threshold:
            break

        # Replace the two clusters with their concatenation and relabel
        # every member of the merged cluster.
        clusters -= set(pair)
        merged = pair[0] + pair[1]
        for i in merged:
            labels[i] = next_label
        clusters.add(merged)
        next_label += 1

    return labels
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment