Skip to content

Instantly share code, notes, and snippets.

@colinpollock
Created February 12, 2023 03:28
Show Gist options
  • Save colinpollock/68569c7bc8dd159e037c20fc14e73a25 to your computer and use it in GitHub Desktop.
Save colinpollock/68569c7bc8dd159e037c20fc14e73a25 to your computer and use it in GitHub Desktop.
"""Make initial clusters of categories to bootstrap top-level categories."""
from collections import defaultdict
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans, MiniBatchKMeans
from j_util import get_rows
def get_docs(rows=None):
    """Group clue rows by lower-cased category and build one text doc per category.

    Args:
        rows: Iterable of clue rows, each a mapping with at least ``'category'``,
            ``'clue'`` and ``'answer'`` keys. Defaults to ``get_rows()``; pass
            rows explicitly to avoid loading the data set (e.g. in tests).

    Returns:
        A ``(docs, categories, cat_to_clues)`` tuple. ``docs[i]`` is the text
        for ``categories[i]``: the category name, a newline, then every
        ``"clue: answer"`` pair of that category joined by tabs.
        ``cat_to_clues`` (a ``defaultdict(list)``) maps each lower-cased
        category to its rows, in input order.

    Note:
        Each row's ``'category'`` value is lower-cased in place.
    """
    if rows is None:
        rows = get_rows()

    cat_to_clues = defaultdict(list)
    for row in rows:
        # Normalise case so e.g. "HISTORY" and "History" share one bucket.
        row['category'] = row['category'].lower()
        cat_to_clues[row['category']].append(row)

    docs = []
    categories = []
    for category, clues in cat_to_clues.items():
        clue_answer_string = '\t'.join(
            '{}: {}'.format(clue['clue'], clue['answer']) for clue in clues
        )
        # The category name is included so its words are vectorized as well.
        docs.append(category + '\n' + clue_answer_string)
        categories.append(category)
    return docs, categories, cat_to_clues
def cluster(vectorizer, docs, num_clusters=100, verbose=False, random_state=None):
    """Vectorize ``docs`` and fit a mini-batch k-means model on the result.

    Args:
        vectorizer: An sklearn text vectorizer; it is (re)fitted here via
            ``fit_transform(docs)``.
        docs: Sequence of document strings to cluster.
        num_clusters: Number of clusters to fit (default 100).
        verbose: Verbosity flag passed through to ``MiniBatchKMeans``.
        random_state: Seed passed to ``MiniBatchKMeans``. ``None`` (the
            default) keeps the previous unseeded, run-to-run varying behavior.

    Returns:
        The fitted ``MiniBatchKMeans``; ``km.labels_[i]`` is the cluster
        assigned to ``docs[i]``.
    """
    X = vectorizer.fit_transform(docs)
    km = MiniBatchKMeans(
        n_clusters=num_clusters,
        init='k-means++',
        n_init=1,
        init_size=1000,
        batch_size=1000,
        verbose=verbose,
        random_state=random_state,
    )
    km.fit(X)
    return km
def top_cluster_words(vectorizer, km, cluster_idx, num_words=None):
    """Return one cluster's vocabulary terms ordered by centroid weight, highest first.

    Args:
        vectorizer: The fitted vectorizer whose vocabulary ``km`` was trained on.
        km: A fitted k-means model exposing ``cluster_centers_``.
        cluster_idx: Index of the cluster to describe.
        num_words: If given, return only the ``num_words`` heaviest terms;
            ``None`` (default) returns the whole vocabulary.

    Returns:
        List of terms for cluster ``cluster_idx``, heaviest first.
    """
    # get_feature_names() was removed in scikit-learn 1.2; prefer its
    # replacement but fall back so older installs keep working.
    get_terms = getattr(vectorizer, 'get_feature_names_out', None)
    if get_terms is None:
        get_terms = vectorizer.get_feature_names
    terms = get_terms()
    # Only this cluster's centroid is needed, so don't argsort the whole matrix.
    order = km.cluster_centers_[cluster_idx].argsort()[::-1]
    return [terms[i] for i in order[:num_words]]
# TF-IDF settings for the per-category documents.
num_features = 10000
use_idf = True

# One document per category, then vectorize and cluster them.
DOCS, CATEGORIES, category_to_clues = get_docs()
vectorizer = TfidfVectorizer(
    stop_words='english',
    min_df=2,        # ignore terms found in fewer than 2 documents
    max_df=0.5,      # ignore terms found in more than half of the documents
    max_features=num_features,
    use_idf=use_idf,
)
km = cluster(vectorizer, DOCS)
def get_categories_in_cluster(category):
    """List every category that k-means put in the same cluster as ``category``.

    The result includes ``category`` itself and follows the order of the
    module-level ``CATEGORIES`` list. Raises ValueError (from ``list.index``)
    if ``category`` is not in ``CATEGORIES``.
    """
    target_label = km.labels_[CATEGORIES.index(category)]
    return [
        name
        for name, label in zip(CATEGORIES, km.labels_)
        if label == target_label
    ]
# Report on the NUM_CATEGORIES largest categories (by clue count), listing the
# other categories k-means grouped with each one.
NUM_CATEGORIES = 100
# Categories with fewer clues than this are counted but not listed.
MIN_CLUES_TO_SHOW = 25
# Number of example clues printed under each listed category.
NUM_SAMPLE_CLUES = 3

sorted_cats = sorted(
    category_to_clues,
    key=lambda category: len(category_to_clues[category]),
    reverse=True,  # stable: ties keep their original order
)[:NUM_CATEGORIES]
seen_categories = set()
for idx, main_category in enumerate(sorted_cats, start=1):
    categories_in_cluster = get_categories_in_cluster(main_category)
    # get_categories_in_cluster() already contains main_category; put it
    # first without listing (and counting) it a second time.
    categories = [main_category] + [
        c for c in categories_in_cluster if c != main_category
    ]
    # Drop categories already reported under an earlier main category.
    categories = [c for c in categories if c not in seen_categories]
    s = sum(len(category_to_clues[cat]) for cat in categories)
    if s == 0:
        continue
    print('## Category {:d}: {} ({} clues)##'.format(
        idx, main_category, len(category_to_clues[main_category])))
    print('cluster: {} categories, {} clues'.format(len(categories), s))
    categories.sort(key=lambda cat: len(category_to_clues[cat]), reverse=True)
    for category in categories:
        num_clues = len(category_to_clues[category])
        if num_clues < MIN_CLUES_TO_SHOW:
            continue
        print(' Cat: {} ({} clues)'.format(category, num_clues))
        for clue in category_to_clues[category][:NUM_SAMPLE_CLUES]:
            print(' Clue: {}: {}'.format(clue['clue'], clue['answer']))
    seen_categories.update(categories)
    print()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment