Skip to content

Instantly share code, notes, and snippets.

@kkew3
Last active December 23, 2023 03:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kkew3/b9fb85ef390685c13733a3e006a7e825 to your computer and use it in GitHub Desktop.
Save kkew3/b9fb85ef390685c13733a3e006a7e825 to your computer and use it in GitHub Desktop.
An approach to evaluating the coherence of a scikit-learn topic model with gensim, reusing the vocabulary already built for scikit-learn.
from collections import Counter
from typing import Dict, Union, List
import numpy as np
from scipy import sparse
import pandas as pd
import spacy
from sklearn.datasets import fetch_20newsgroups
from sklearn.decomposition import LatentDirichletAllocation
from gensim.models.coherencemodel import CoherenceModel
# spaCy pipeline used only for token-level attributes (is_alpha, is_stop,
# lemma_); the named-entity recognizer and dependency parser are disabled.
nlp = spacy.load(
    'en_core_web_md',
    disable=['ner', 'parser'],
)
# Raw 20 Newsgroups posts with headers, signature footers and quoted
# replies stripped off.
corpus = fetch_20newsgroups(
    remove=('headers', 'footers', 'quotes'),
).data
def map_filter_words(doc):
    """Lazily yield lemmas of the content words in *doc*.

    A token is kept only if it is purely alphabetic (``is_alpha``) and is
    not a stop word (``is_stop``); for every kept token its ``lemma_`` is
    yielded, in document order.
    """
    for token in doc:
        # Skip non-alphabetic tokens and stop words.
        if not token.is_alpha or token.is_stop:
            continue
        yield token.lemma_
# Tokenize the corpus: ``texts`` holds each document's filtered lemmas (later
# passed to gensim's CoherenceModel), and ``tf`` accumulates global term
# frequencies, whose keys become the vocabulary.
texts = []  # tokenized corpus
tf = Counter()  # global term frequency
# nlp.pipe streams documents through the pipeline in batches, which is
# faster than calling nlp() on each text individually.
for doc in nlp.pipe(corpus):
    # Filter once per document and reuse the result for both structures,
    # so tf and texts are guaranteed to agree.
    tokens = list(map_filter_words(doc))
    tf.update(tokens)
    texts.append(tokens)
vocab = sorted(tf)  # the vocabulary; column j of doc_word counts vocab[j]
word2col = {w: j for j, w in enumerate(vocab)}

# Collect (row, column, count) triplets for the nonzero entries only, instead
# of densifying every document to the full vocabulary width.
rows, cols, counts = [], [], []
for i, doc in enumerate(texts):
    # Term frequency per document. Every token in texts was counted into tf,
    # so the word2col lookup cannot miss.
    for word, count in Counter(doc).items():
        rows.append(i)
        cols.append(word2col[word])
        counts.append(count)

# Document-term count matrix: one row per document, one column per word.
doc_word = sparse.csr_matrix(
    (counts, (rows, cols)),
    shape=(len(texts), len(vocab)),
    dtype=int,
)
# Fit scikit-learn's LDA (default hyper-parameters) on the document-term
# counts; fit() returns the fitted estimator itself.
lda = LatentDirichletAllocation().fit(doc_word)
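### NOTE HERE: wrap sklearn's topic-word matrix in a minimal object exposing get_topics(), the method gensim's CoherenceModel queries when given model=.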
class DummyTopicModel:
    """Fake a topic model for gensim by exposing only ``get_topics()``.

    Wraps a 2-D topic-word weight matrix (one row per topic) and rescales
    each row so that it sums to one.
    """

    def __init__(self, lam):
        """Store *lam* with every row divided by that row's total.

        lam: the variational parameters for the topic-word distribution,
        e.g. sklearn's ``LatentDirichletAllocation.components_``.
        NOTE(review): a row summing to zero would yield NaN/inf entries.
        """
        row_totals = np.sum(lam, axis=1, keepdims=True)
        self.lam = lam / row_totals

    def get_topics(self):
        """Return the row-normalized topic-word matrix."""
        return self.lam
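### NOTE HERE: a minimal stand-in for gensim's Dictionary built from the sklearn vocabulary, so token id j is vocab[j], matching column j of the topic-word matrix.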
class DummyDictionary:
    """Minimal stand-in for ``gensim.corpora.Dictionary`` over a fixed vocabulary.

    Token ids are positions in the vocabulary: id ``j`` maps to ``vocab[j]``,
    so the id space lines up with the columns of a document-term matrix built
    over the same vocabulary.
    """

    def __init__(self, vocab: List[str]):
        """Index *vocab* in both directions.

        vocab: the ordered vocabulary. Any sequence of tokens (list, tuple,
        NumPy array) is accepted; it is copied into a new list.
        """
        self.token2id = {w: j for j, w in enumerate(vocab)}
        # Keep a private copy so later changes to the caller's sequence cannot
        # desynchronize id2token from token2id.
        self.id2token = list(vocab)

    def __getitem__(self, item):
        """Return the token whose id is *item*."""
        return self.id2token[item]

    def __contains__(self, item):
        """Return True iff *item* is a valid token id.

        Python ints and NumPy integer scalars (e.g. ids produced by
        ``np.argsort``) are both accepted; anything else, including token
        strings, is reported as absent.
        """
        if isinstance(item, (int, np.integer)):
            return 0 <= item < len(self.id2token)
        return False
# Score the sklearn topics with gensim's NPMI coherence, reusing the same
# tokenized texts and vocabulary order the LDA model was trained on.
cm = CoherenceModel(
    model=DummyTopicModel(lda.components_),
    dictionary=DummyDictionary(vocab),
    texts=texts,
    coherence='c_npmi',
)
per_topic = cm.get_coherence_per_topic()
coh = np.asarray(per_topic)  # one coherence score per topic
print('average topic coherence:', coh.mean())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment