# Created November 17, 2023 18:28
# Cluster Suggestions
import itertools
import sys
from typing import List

import numpy
import torch
from scipy.cluster.vq import kmeans, kmeans2, vq, whiten
from transformers import AutoTokenizer, GPT2Model
def magnitude(x):
    """Return the Euclidean (L2) length of the 1-D vector ``x``.

    Args:
        x: A 1-D array-like of numbers.

    Returns:
        The square root of the dot product of ``x`` with itself.
    """
    return numpy.dot(x, x) ** 0.5
def distance(a, b):
    """Return the Euclidean distance between the 1-D arrays ``a`` and ``b``.

    Args:
        a: First point, as a numpy array.
        b: Second point, as a numpy array of the same length as ``a``.

    Returns:
        The L2 length of ``b - a``.
    """
    delta = b - a
    return numpy.dot(delta, delta) ** 0.5
# Load the pretrained "gpt2" tokenizer and model once, at import time. Both are
# read as module-level globals by embeddings_from_strings.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2")
def embeddings_from_strings(lines: List[str], normalize: bool = True):
    """Embed each string with the module-level GPT-2 model, one row per string.

    Every line is tokenized and run through ``model``; the final-layer hidden
    states are averaged over the token axis to give one vector per line.

    Args:
        lines: Strings to embed.
        normalize: When true, scale each embedding to unit Euclidean length
            (an all-zero embedding is left untouched).

    Returns:
        A 2-D numpy array whose i-th row is the embedding of ``lines[i]``.

    Raises:
        ValueError: If ``lines`` is empty, or a line tokenizes to no tokens.
    """
    all_embeddings = []
    # TODO: embed in batches (e.g. via itertools.islice) for better GPU
    # utilization instead of one forward pass per line.
    with torch.no_grad():  # Inference only; skip autograd bookkeeping.
        for line in lines:
            tokens = tokenizer(line, return_tensors="pt")
            if tokens["input_ids"].shape[-1] == 0:
                raise ValueError(f"line produced no tokens: {line!r}")
            model_output = model(**tokens)
            hidden_states = model_output.last_hidden_state[0].numpy()
            # GPT-2 attention is causal, so the state at position 0 depends
            # only on the first token. Averaging over all positions makes the
            # embedding reflect the whole line.
            embedding = hidden_states.mean(axis=0)
            if normalize:
                norm = numpy.linalg.norm(embedding)
                if norm > 0:
                    embedding = embedding / norm
            all_embeddings.append(embedding)
    if not all_embeddings:
        raise ValueError("embeddings_from_strings needs at least one line")
    return numpy.vstack(all_embeddings)
def pairwise_similarity(embeddings):
    """Return the cosine-similarity matrix between all rows of ``embeddings``.

    Args:
        embeddings: 2-D array-like with one entry per row.

    Returns:
        A square numpy array ``s`` where ``s[i, j]`` is the cosine similarity
        of rows ``i`` and ``j``. Rows that are entirely zero have similarity
        0 with every row (including themselves) rather than NaN.

    Raises:
        ValueError: If ``embeddings`` is not 2-D.
    """
    embeddings = numpy.asarray(embeddings, dtype=float)
    if embeddings.ndim != 2:
        raise ValueError("embeddings must be a 2-D array, one entry per row")
    # Row norms kept as a column vector so they broadcast across each row.
    norms = numpy.linalg.norm(embeddings, axis=1, keepdims=True)
    # Avoid dividing by zero for all-zero rows; they stay all-zero.
    safe_norms = numpy.where(norms == 0, 1.0, norms)
    normalized = embeddings / safe_norms
    return normalized @ normalized.T
def cluster(lines, num_clusters: int):
    """Group ``lines`` into topic clusters via k-means over their embeddings.

    The embeddings are whitened, k-means centroids are fitted on the whitened
    data, and each line is assigned to its nearest centroid in that same
    whitened space.

    Args:
        lines: Iterable of strings to cluster.
        num_clusters: Number of clusters to request from k-means.

    Returns:
        A list of ``num_clusters`` lists; list ``i`` holds the lines nearest to
        centroid ``i``. Some lists may be empty, since k-means can return
        fewer centroids than requested.

    Raises:
        ValueError: If ``num_clusters`` is less than 1.
    """
    if num_clusters < 1:
        raise ValueError("num_clusters must be at least 1")
    clusters = [[] for _ in range(num_clusters)]
    lines = list(lines)
    if not lines:
        return clusters

    embeddings = embeddings_from_strings(lines)
    whitened = whiten(numpy.asarray(embeddings, dtype=numpy.float64))
    # scipy returns a tuple of (centroids, final distortion).
    cluster_centroids = kmeans(whitened, k_or_guess=num_clusters)[0]
    # The centroids are the 'high level topics'. vq gives, for every row, the
    # index of its nearest centroid.
    nearest_centroid, _ = vq(whitened, cluster_centroids)
    for line, centroid_idx in zip(lines, nearest_centroid):
        clusters[centroid_idx].append(line)
    return clusters
def main(filename, num_clusters):
    """Cluster the lines of a text file and print each cluster.

    Args:
        filename: Path of a text file with one suggestion per line. Blank
            lines are skipped and surrounding whitespace is stripped.
        num_clusters: Number of clusters; an int or a string holding an int
            (as it arrives from the command line).
    """
    with open(filename, 'rt') as fin:
        stripped = (raw.strip() for raw in fin)
        lines = [text for text in stripped if text]
    clusters = cluster(lines, int(num_clusters))
    for idx, c in enumerate(clusters):
        print(f"Cluster {idx}: " + ", ".join(c))
if __name__ == "__main__":
    # Expect exactly two arguments: the input file and the cluster count.
    if len(sys.argv) != 3:
        sys.exit(f"usage: {sys.argv[0]} <suggestions-file> <num-clusters>")
    main(sys.argv[1], sys.argv[2])
