BERT embeddings similarity & clustering
import math
import torch
from sklearn.cluster import AgglomerativeClustering
from sentence_transformers import SentenceTransformer
def normalize(x):
    """
    Needed for cosine(), because otherwise we'd be taking the dot product of
    non-normalized vectors (which isn't cosine similarity). BERT embedding
    values typically range over roughly [-7, 7], so this matters.
    """
    return x / x.norm(dim=1)[:, None]
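
# Hedged illustration (not from the original gist): _norm_demo is a made-up
# helper showing that normalize() maps each row to unit length, so dot
# products become true cosine similarities.
def _norm_demo():
    v = torch.tensor([[3., 4.], [0., 2.]])  # made-up vectors with norms 5 and 2
    u = normalize(v)
    return u.norm(dim=1)  # -> tensor([1., 1.])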
def cosine(x, y=None, norm=False):
    """
    x: LHS (i.e., your query documents)
    y: RHS to compare against (i.e., the database). If None, compute pairwise
       cosine (x against x).
    norm: normalize the output? Something to fiddle with; I've found it
       improves AgglomerativeClustering for some reason.
    """
    x = torch.tensor(x)
    if y is None:
        x = y = normalize(x)
    else:
        y = torch.tensor(y)
        # normalize x and y together first, so both live on the same unit sphere
        both = torch.cat((x, y), 0)
        both = normalize(both)
        x, y = both[:x.shape[0]], both[x.shape[0]:]
    sim = torch.mm(x, y.T)
    if norm: sim = normalize(sim)
    # This part is key. sklearn's hierarchical clustering expects non-negative
    # precomputed distances, while cosine similarity naturally lies in [-1, 1].
    # I've fiddled with all sorts of ways to make it positive, including
    # absolute(), sim.acos()/math.pi, etc.; this linear rescale is what I
    # settled on.
    sim = (sim + 1.) / 2.
    # Prefer working with distance rather than similarity, since most default
    # sorting (e.g. numpy, pandas) is ascending.
    dist = 1. - sim
    return dist.numpy()
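
# Hedged sanity check of cosine() in both modes; the toy vectors and the
# _cosine_demo name are illustrative assumptions, not part of the original.
def _cosine_demo():
    a = [[1., 0.], [0., 1.]]
    b = [[1., 0.]]
    pairwise = cosine(a)     # (2, 2) distances of a vs a; diagonal is 0
    query_db = cosine(b, a)  # (1, 2) distances of b vs each row of a
    return pairwise, query_db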
def cluster(x):
    """
    I've found AgglomerativeClustering on precomputed cosine distances to be
    the best-performing clusterer for BERT embeddings. I've tried
    sklearn.KMeans, faiss.KMeans, DBSCAN, HDBSCAN, and others. Unfortunately,
    unlike KMeans with its elbow method on silhouette scores, I don't know how
    to find the optimal cluster count for Agglomerative (one possible sketch
    follows below).
    """
    dists = cosine(x, norm=True)
    # Oddball default cluster count that happens to work well for my dataset.
    nc = math.floor(1 + 4 * math.log10(dists.shape[0]))
    # Note: sklearn >= 1.4 renamed `affinity` to `metric`.
    agg = AgglomerativeClustering(n_clusters=nc, affinity='precomputed', linkage='average')
    return agg.fit_predict(dists)
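
# A sketch for the open question above: sweep n_clusters and score each
# labeling with a silhouette score on the same precomputed distances. This is
# my assumption, not the author's method; _best_n_clusters and the lo/hi
# bounds are hypothetical.
def _best_n_clusters(dists, lo=2, hi=20):
    from sklearn.metrics import silhouette_score
    best_n, best_score = lo, -1.
    # silhouette requires 2 <= n_clusters <= n_samples - 1
    for n in range(lo, min(hi, dists.shape[0] - 1) + 1):
        agg = AgglomerativeClustering(n_clusters=n, affinity='precomputed', linkage='average')
        labels = agg.fit_predict(dists)
        score = silhouette_score(dists, labels, metric='precomputed')
        if score > best_score:
            best_n, best_score = n, score
    return best_n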
# UKPLab/sentence-transformers models trained on the NLI + STS-B datasets are
# fine-tuned on cosine similarity of sentence pairs, so we prefer them over
# other models (even newer Transformers architectures like Longformer) because
# they preserve cosine similarity in the embedding output. roberta-base is my
# favorite.
encoder = SentenceTransformer('roberta-base-nli-stsb-mean-tokens')
db = encoder.encode(sentences_from_somewhere)  # your own list of strings
query = encoder.encode(["Find closest matches to this sentence"])
dists = cosine(query, db)  # do your own argsort or whatever (sketch below)
db_labels = cluster(db)
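
# Minimal sketch of the "do your own argsort" step above; k=5 and the variable
# names are illustrative assumptions. Distances sort ascending, so the
# smallest values are the closest matches.
import numpy as np
top_k = np.argsort(dists[0])[:5]  # indices of the 5 closest documents in db
# closest = [sentences_from_somewhere[i] for i in top_k]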