Cluster Suggestions
import itertools
import sys
from typing import List

import numpy
import torch
from scipy.cluster.vq import kmeans, whiten
from transformers import AutoTokenizer, GPT2Model

def magnitude(x):
    # Euclidean (L2) norm of a vector.
    return numpy.dot(x, x) ** 0.5

def distance(a, b):
    # Euclidean distance between two vectors.
    delta = b - a
    return numpy.dot(delta, delta) ** 0.5

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2")

def embeddings_from_strings(lines: List[str], normalize: bool = True):
    all_embeddings = list()
    # Maybe itertools for easy batching:
    # islice('ABCDEFG', 2) --> A B
    # islice('ABCDEFG', 2, 4) --> C D
    # islice('ABCDEFG', 2, None) --> C D E F G
    for line in lines:
        # We should do this in batches for better GPU utilization.
        tokens = tokenizer(line, return_tensors="pt")
        model_output = model(**tokens)
        # Use the hidden state of the first token as the sentence embedding.
        embeddings = model_output.last_hidden_state.detach().numpy()[0, 0, :]
        if normalize:
            embeddings /= magnitude(embeddings)
        all_embeddings.append(embeddings)
    return numpy.vstack(all_embeddings)
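# A hedged sketch, not part of the original gist: a batched variant using the
# itertools.islice pattern suggested in the comments above. The batch_size of
# 16 is an arbitrary assumption, and GPT-2 ships without a pad token, so we
# reuse the EOS token before padding.
def embeddings_from_strings_batched(lines: List[str], batch_size: int = 16, normalize: bool = True):
    tokenizer.pad_token = tokenizer.eos_token  # Required before padding with GPT-2.
    all_embeddings = list()
    it = iter(lines)
    while True:
        batch = list(itertools.islice(it, batch_size))
        if not batch:
            break
        tokens = tokenizer(batch, return_tensors="pt", padding=True)
        model_output = model(**tokens)
        # First-token hidden state per sequence; the tokenizer pads on the
        # right, so position 0 is a real token for every line in the batch.
        batch_embeddings = model_output.last_hidden_state.detach().numpy()[:, 0, :]
        if normalize:
            batch_embeddings /= numpy.linalg.norm(batch_embeddings, axis=1, keepdims=True)
        all_embeddings.append(batch_embeddings)
    return numpy.vstack(all_embeddings)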
def pairwise_similarity(embeddings):
    # Assumes one row is one entry.
    # Renormalize each row to unit length:
    magnitudes = numpy.sqrt(numpy.sum(embeddings * embeddings, axis=1))  # 1D array, one norm per row.
    assert magnitudes.shape[0] == embeddings.shape[0]
    normalized = embeddings / magnitudes[:, None]  # Broadcast row norms across columns.
    cosine_similarity = normalized @ normalized.T
    return cosine_similarity
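# Illustrative usage, assuming the functions above: for n input lines,
# pairwise_similarity returns an n x n symmetric matrix with ones on the
# diagonal (each row compared against itself).
#
#   sims = pairwise_similarity(embeddings_from_strings(["red apple", "green apple", "blue car"]))
#   assert sims.shape == (3, 3)
#   # Expect sims[0, 1] > sims[0, 2] if the embeddings capture topical similarity.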
def cluster(lines, num_clusters: int):
    embeddings = embeddings_from_strings(lines)
    # Whiten once so the k-means centroids and the per-line embeddings live in
    # the same rescaled space when we measure distances below.
    whitened = whiten(embeddings)
    cluster_centroids = kmeans(whitened, k_or_guess=num_clusters)[0]  # scipy returns a tuple of (centroids, final distortion).
    clusters = list()
    for _ in range(num_clusters):
        clusters.append(list())
    # The centroids are the 'high level topics'. Match each suggestion to the topic clusters.
    for line, emb in zip(lines, whitened):
        # Find the nearest centroid to this embedding, then add the line to that group.
        min_dist = 1e10
        maybe_centroid = 0
        for idx, centroid in enumerate(cluster_centroids):
            dist = distance(emb, centroid)
            if dist < min_dist:
                min_dist = dist
                maybe_centroid = idx
        clusters[maybe_centroid].append(line)
    return clusters
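# A hedged aside, not in the original gist: scipy can do the nearest-centroid
# assignment above in one vectorized call via scipy.cluster.vq.vq, which
# returns the centroid index (and distance) for each observation:
#
#   from scipy.cluster.vq import vq
#   assignments, _ = vq(whitened, cluster_centroids)
#   for line, idx in zip(lines, assignments):
#       clusters[idx].append(line)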
def main(filename, num_clusters: int):
    with open(filename, 'rt') as fin:
        # Strip trailing newlines so cluster output doesn't carry line breaks.
        lines = [line.strip() for line in fin if line.strip()]
    clusters = cluster(lines, num_clusters)
    for idx, c in enumerate(clusters):
        print(f"Cluster {idx}: " + ", ".join(c))

if __name__ == "__main__":
    main(sys.argv[1], int(sys.argv[2]))
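# Example invocation (hypothetical file name), clustering into 5 topics:
#   python cluster_suggestions.py suggestions.txt 5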