Skip to content

Instantly share code, notes, and snippets.

@JosephCatrambone
Created November 17, 2023 18:28
Show Gist options
  • Save JosephCatrambone/baaef25d338dd6b8b332e76a0445ba0d to your computer and use it in GitHub Desktop.
Save JosephCatrambone/baaef25d338dd6b8b332e76a0445ba0d to your computer and use it in GitHub Desktop.
Cluster Suggestions
import itertools
import sys
from typing import List

import numpy
import torch
from scipy.cluster.vq import kmeans, kmeans2, vq, whiten
from transformers import AutoTokenizer, GPT2Model
def magnitude(x):
    """Return the Euclidean (L2) length of the vector ``x``.

    Computed as the square root of ``x`` dotted with itself.
    """
    self_dot = numpy.dot(x, x)
    return self_dot ** 0.5
def distance(a, b):
    """Return the Euclidean distance between vectors ``a`` and ``b``.

    Both arguments must support element-wise subtraction (e.g. numpy arrays).
    """
    offset = b - a
    squared_length = numpy.dot(offset, offset)
    return squared_length ** 0.5
# The tokenizer and base model are loaded once, at import time, from the same
# pretrained checkpoint; the embedding code below reuses these module globals.
_CHECKPOINT_NAME = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT_NAME)
model = GPT2Model.from_pretrained(_CHECKPOINT_NAME)
def embeddings_from_strings(lines: List[str], normalize: bool = True):
    """Embed each string in ``lines`` with the module-level GPT-2 model.

    Each line is tokenized and run through ``model`` on its own. The final
    hidden states are mean-pooled over the token positions, giving one vector
    per line.

    :param lines: Strings to embed; the output has one row per string, in order.
    :param normalize: If True, scale every row to unit Euclidean length
        (all-zero rows are left as-is instead of dividing by zero).
    :returns: A numpy array with one row per entry in ``lines``. For empty
        input this is a ``(0, hidden_size)`` array.
    :raises ValueError: If a line produces no tokens (e.g. the empty string),
        since there is nothing to pool.
    """
    all_embeddings = list()
    # TODO: batch lines (e.g. via itertools.islice) for better GPU utilization.
    # Batching needs padding + attention masks, and padded positions must then
    # be excluded from the mean below.
    for line in lines:
        tokens = tokenizer(line, return_tensors="pt")
        if tokens["input_ids"].shape[-1] == 0:
            raise ValueError(f"Line {line!r} produced no tokens; cannot embed it.")
        # Inference only: don't build an autograd graph for every line.
        with torch.no_grad():
            model_output = model(**tokens)
            # GPT-2 is causal, so the hidden state at position 0 has only seen
            # the first token. Averaging over every position makes the vector
            # depend on the whole line rather than just its first token.
            embedding = model_output.last_hidden_state[0].mean(dim=0).numpy()
        if normalize:
            length = magnitude(embedding)
            if length > 0:
                embedding = embedding / length
        all_embeddings.append(embedding)
    if not all_embeddings:
        # numpy.vstack rejects an empty list; return a well-shaped empty matrix.
        return numpy.empty((0, model.config.hidden_size), dtype=numpy.float32)
    return numpy.vstack(all_embeddings)
def pairwise_similarity(embeddings):
    """Return the cosine similarity between every pair of rows in ``embeddings``.

    One row is one entry. Rows are renormalized to unit length here, so
    callers may pass unnormalized vectors.

    :param embeddings: 2-D array-like with one entry per row.
    :returns: A square ``(rows, rows)`` array where element ``[i, j]`` is the
        cosine similarity of row ``i`` and row ``j``. An all-zero row has no
        direction, so its similarity to every row (itself included) is 0.
    """
    embeddings = numpy.asarray(embeddings)
    # Euclidean length of each row, kept as a (rows, 1) column so it
    # broadcasts across that row's columns. (The previous code used the squared
    # length here, which only gives cosine similarity for already-unit rows.)
    norms = numpy.linalg.norm(embeddings, axis=1, keepdims=True)
    # Avoid 0/0 for all-zero rows; those rows stay zero after division.
    safe_norms = numpy.where(norms == 0, 1.0, norms)
    normalized = embeddings / safe_norms
    return normalized @ normalized.T
def cluster(lines, num_clusters: int):
    """Group ``lines`` into at most ``num_clusters`` topic clusters.

    Lines are embedded with ``embeddings_from_strings``. The embeddings are
    whitened (each feature scaled to unit variance) and k-means finds the
    centroids, i.e. the 'high level topics'. Each line then goes to its nearest
    centroid, measured in the same whitened space the centroids live in.

    :param lines: The strings to cluster.
    :param num_clusters: Number of clusters to request. Values such as the
        raw string from the command line are converted with ``int``.
    :returns: A list of ``num_clusters`` lists of lines. Some lists may be
        empty, e.g. when there are fewer lines than clusters or k-means
        returns fewer centroids than requested.
    :raises ValueError: If ``num_clusters`` is less than 1.
    """
    num_clusters = int(num_clusters)
    if num_clusters < 1:
        raise ValueError(f"num_clusters must be at least 1, got {num_clusters}")
    clusters = [list() for _ in range(num_clusters)]
    if not lines:
        return clusters

    embeddings = embeddings_from_strings(lines)
    # Centroids and assignments must use the same (whitened) coordinates;
    # comparing raw embeddings to whitened centroids gives meaningless distances.
    whitened = whiten(embeddings)
    # k-means cannot pick more distinct initial centroids than there are points.
    k = min(num_clusters, len(lines))
    centroids, _distortion = kmeans(whitened, k_or_guess=k)
    # vq returns, for each row, the index of its nearest centroid.
    labels, _distances = vq(whitened, centroids)
    for line, label in zip(lines, labels):
        clusters[label].append(line)
    return clusters
def main(filename, num_clusters):
    """Cluster the non-blank lines of ``filename`` and print each cluster.

    :param filename: Path to a UTF-8 text file with one entry per line.
    :param num_clusters: Number of clusters to request. Strings such as the
        raw command-line argument are converted with ``int``.
    """
    with open(filename, "rt", encoding="utf-8") as fin:
        # Strip surrounding whitespace/newlines, so they are not embedded or
        # printed, and skip blank lines, which carry nothing to cluster.
        lines = [text for text in (raw.strip() for raw in fin) if text]
    clusters = cluster(lines, int(num_clusters))
    # enumerate supplies the cluster index; each cluster is itself a list of lines.
    for idx, members in enumerate(clusters):
        print(f"Cluster {idx}: " + ", ".join(members))
if __name__ == "__main__":
    # Expected invocation: <script> <filename> <num_clusters>
    if len(sys.argv) != 3:
        sys.exit(f"Usage: {sys.argv[0]} <filename> <num_clusters>")
    # Command-line arguments are strings; the cluster count must be an int.
    main(sys.argv[1], int(sys.argv[2]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment