@lefnire
Last active September 16, 2022 13:27
BERT embeddings similarity & clustering
import math
import torch
from sklearn.cluster import AgglomerativeClustering
from sentence_transformers import SentenceTransformer

def normalize(x):
    """
    Needed by cosine(): without it we'd be taking a dot product on non-normalized vectors, which isn't cosine
    similarity. BERT embeddings typically range over roughly [-7, 7], so normalization matters.
    """
    return x / x.norm(dim=1)[:, None]

def cosine(x, y=None, norm=False):
    """
    x: LHS (i.e. your query documents)
    y: RHS to compare against (i.e. the database). If omitted, this computes pairwise cosine (x against x).
    norm: normalize the output? Something to fiddle with; I've found it improves Agglomerative Clustering for some reason.
    """
    x = torch.tensor(x)
    if y is None:
        x = y = normalize(x)
    else:
        y = torch.tensor(y)
        # normalize x and y together first
        both = torch.cat((x, y), 0)
        both = normalize(both)
        x, y = both[:x.shape[0]], both[x.shape[0]:]
    sim = torch.mm(x, y.T)
    if norm: sim = normalize(sim)
    # This part is key. sklearn's hierarchical clustering expects a non-negative precomputed distance matrix. I've
    # fiddled with all sorts of ways to positive-ify cosine similarity (which naturally lies in [-1, 1]), including
    # absolute(), sim.acos() / math.pi, etc.; the simple shift below works well.
    sim = (sim + 1.) / 2.
    # Prefer working with distance rather than similarity, since most default sorting (e.g. numpy, pandas) is ascending.
    dist = 1. - sim
    return dist.numpy()
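
# Quick sanity check (not part of the original gist): cosine() returns distances in [0, 1], where 0. means identical
# direction and 1. means opposite; two orthogonal unit vectors land at 0.5 because of the (sim + 1) / 2 shift above.
assert cosine([[1., 0.], [0., 1.]])[0, 1] == 0.5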

def cluster(x):
    """
    I've found AgglomerativeClustering on precomputed cosine distances to be the best-performing clusterer for BERT
    embeddings. I've tried sklearn.KMeans, faiss.KMeans, DBSCAN, HDBSCAN, and others. Unfortunately, unlike KMeans with
    its elbow method on silhouette scores, I don't know how to find the optimal number of clusters for Agglomerative.
    """
    dists = cosine(x, norm=True)
    nc = math.floor(1 + 4 * math.log10(dists.shape[0]))  # kind of an odd-ball default, but a good value for my dataset
    agg = AgglomerativeClustering(n_clusters=nc, affinity='precomputed', linkage='average')
    return agg.fit_predict(dists)
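
# The docstring in cluster() notes that, unlike KMeans, there's no obvious way to pick n_clusters for Agglomerative.
# A minimal sketch (not part of the original gist; the 2..20 search range and the cluster_auto name are assumptions):
# silhouette_score also accepts a precomputed distance matrix, so you can sweep n_clusters and keep the best value.
def cluster_auto(x, max_clusters=20):
    from sklearn.metrics import silhouette_score
    dists = cosine(x, norm=True)
    best_score, best_labels = -1., None
    for nc in range(2, min(max_clusters, dists.shape[0] - 1) + 1):
        agg = AgglomerativeClustering(n_clusters=nc, affinity='precomputed', linkage='average')
        labels = agg.fit_predict(dists)
        score = silhouette_score(dists, labels, metric='precomputed')
        if score > best_score:
            best_score, best_labels = score, labels
    return best_labels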

# UKPLab/sentence-transformers models trained on the NLI + STS-B datasets are fine-tuned on cosine similarity of
# documents, so we prefer them over other (even newer) Transformer models like Longformer, since they preserve cosine
# similarity in the embedding output. roberta-base is my favorite.
encoder = SentenceTransformer('roberta-base-nli-stsb-mean-tokens')
db = encoder.encode(sentences_from_somewhere)
query = encoder.encode(["Find closest matches to this sentence"])
dists = cosine(query, db) # do your own argsort or whatever
db_labels = cluster(db)
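
# A minimal sketch of the "do your own argsort" step above: dists has shape (n_queries, n_db), so sorting each row
# ascending puts the closest database sentences first. top_k and the numpy import are assumptions for illustration.
import numpy as np
top_k = 5
nearest = np.argsort(dists[0])[:top_k]  # indices into sentences_from_somewhere, nearest first
print([(int(i), float(dists[0][i])) for i in nearest])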