Skip to content

Instantly share code, notes, and snippets.

@bbengfort
Created March 6, 2020 01:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bbengfort/da0df7b0a9864ca65c2dbd6701b718d2 to your computer and use it in GitHub Desktop.
Save bbengfort/da0df7b0a9864ca65c2dbd6701b718d2 to your computer and use it in GitHub Desktop.
import numpy as np
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import Pipeline
from yellowbrick.datasets import load_hobbies
class LDAViz(object):
    """
    Compute the statistics behind an LDAvis-style topic model display.

    A vectorizer + LDA pipeline is fitted on raw documents. ``transform``
    then derives, for the documents it is given, the term saliencies
    (Chuang et al., 2012) and per-topic term relevances (Sievert &
    Shirley, 2014) from the fitted model.

    Parameters
    ----------
    transformer : sklearn.pipeline.Pipeline, optional
        A user specified vectorizer and LDA model in a pipeline. The first
        step is used as the vectorizer and the last step as the LDA model.
    n_components : int, optional (default=10)
        Number of topics. Not used if transformer is specified.
    relevance_lambda : float, optional (default=0.6)
        Weight in [0, 1] given to the log topic-term probability when
        computing relevance; the remainder goes to the log lift.

    Attributes
    ----------
    doc_lengths_ : ndarray of shape (n_docs,)
        Total term count of each transformed document.
    term_freqs_ : ndarray of shape (n_terms,)
        Total count of each vocabulary term over the transformed documents.
    doc_topic_dists_ : ndarray of shape (n_docs, n_topics)
        Row-normalized topic distribution of each document.
    topic_term_dists_ : ndarray of shape (n_topics, n_terms)
        Row-normalized term distribution of each topic.
    saliencies_ : ndarray of shape (n_terms,)
        Saliency of each vocabulary term.
    relevencies_ : ndarray of shape (n_topics, n_terms)
        Relevance of each term to each topic (attribute spelling kept as
        originally declared).
    """

    def __init__(self, transformer=None, n_components=10, relevance_lambda=0.6):
        if not 0.0 <= relevance_lambda <= 1.0:
            raise ValueError("relevance_lambda must be between 0 and 1")
        # n_components has to be assigned first: the transformer setter reads
        # it when it builds the default pipeline.
        self.n_components = n_components
        self.relevance_lambda = relevance_lambda
        self.transformer = transformer

    @property
    def transformer(self):
        """The vectorizer + LDA pipeline used by this instance."""
        return self._transformer

    @transformer.setter
    def transformer(self, transformer):
        if transformer is None:
            transformer = Pipeline([
                ("vec", CountVectorizer()),
                ("lda", LatentDirichletAllocation(n_components=self.n_components)),
            ])
        if not isinstance(transformer, Pipeline):
            raise ValueError(
                "transformer must be a pipeline with a vectorizer and LDA"
            )
        self._transformer = transformer

    def fit(self, X):
        """Fit the vectorizer + LDA pipeline on the documents X; return self."""
        self.transformer.fit(X)
        return self

    def transform(self, X):
        """
        Compute document, term and topic statistics for the documents X.

        Sets ``doc_lengths_``, ``term_freqs_``, ``doc_topic_dists_``,
        ``topic_term_dists_``, ``saliencies_`` and ``relevencies_``.

        Returns
        -------
        ndarray of shape (n_docs, n_topics)
            The row-normalized document-topic distributions.
        """
        # TODO: check if fitted. The check must inspect steps[-1][1] (the
        # estimator); steps[-1] itself is a (name, estimator) tuple.
        vectorizer = self.transformer.steps[0][1]
        lda = self.transformer.steps[-1][1]

        dtm = vectorizer.transform(X)
        # asarray + ravel flattens the axis sums whether they come back as
        # np.matrix (scipy sparse matrices) or as plain ndarrays.
        self.doc_lengths_ = np.asarray(dtm.sum(axis=1)).ravel()
        self.term_freqs_ = np.asarray(dtm.sum(axis=0)).ravel()
        self.doc_topic_dists_ = self._row_norm(self.transformer.transform(X))
        self.topic_term_dists_ = self._row_norm(lda.components_)

        # Topic weight = token mass assigned to each topic across documents.
        topic_props = self._proportions(
            self.doc_lengths_ @ self.doc_topic_dists_
        )
        term_props = self._proportions(self.term_freqs_)

        self.saliencies_ = self._saliency(
            self.topic_term_dists_, topic_props, term_props
        )
        self.relevencies_ = self._relevance(
            self.topic_term_dists_, term_props, self.relevance_lambda
        )
        return self.doc_topic_dists_

    def fit_transform(self, X):
        """Fit on X, then return ``transform(X)``."""
        self.fit(X)
        return self.transform(X)

    @property
    def vocab_(self):
        """List of vocabulary terms, in the vectorizer's column order."""
        # TODO: check if fitted (see the note in transform).
        vectorizer = self.transformer.steps[0][1]
        # get_feature_names() is gone from recent scikit-learn releases;
        # prefer its replacement and fall back for older versions.
        if hasattr(vectorizer, "get_feature_names_out"):
            return list(vectorizer.get_feature_names_out())
        return vectorizer.get_feature_names()

    @staticmethod
    def _row_norm(dists):
        """Scale each row to sum to 1; all-zero rows are left as zeros."""
        dists = np.asarray(dists, dtype=float)
        totals = dists.sum(axis=1, keepdims=True)
        return np.divide(
            dists, totals, out=np.zeros_like(dists), where=totals != 0
        )

    @staticmethod
    def _proportions(weights):
        """Scale a 1-d weight vector to sum to 1 (zeros if the total is 0)."""
        weights = np.asarray(weights, dtype=float)
        total = weights.sum()
        if total == 0:
            return np.zeros_like(weights)
        return weights / total

    @staticmethod
    def _saliency(topic_term, topic_props, term_props):
        """
        Saliency of each term: P(w) * sum_t P(t|w) * log(P(t|w) / P(t)).

        P(t|w) is obtained from the topic-term distributions and the topic
        proportions by Bayes' rule; zero-probability entries contribute 0.
        """
        topic_term = np.asarray(topic_term, dtype=float)
        topic_props = np.asarray(topic_props, dtype=float)[:, None]
        joint = topic_term * topic_props
        marginal = joint.sum(axis=0, keepdims=True)
        topic_given_term = np.divide(
            joint, marginal, out=np.zeros_like(joint), where=marginal > 0
        )
        # Where P(t|w) is 0 the ratio is left at 1 so its log is 0.
        ratio = np.divide(
            topic_given_term,
            topic_props,
            out=np.ones_like(joint),
            where=topic_given_term > 0,
        )
        distinctiveness = (topic_given_term * np.log(ratio)).sum(axis=0)
        return np.asarray(term_props, dtype=float) * distinctiveness

    @staticmethod
    def _relevance(topic_term, term_props, lambda_):
        """
        Relevance of each term to each topic:
        lambda * log(phi) + (1 - lambda) * log(phi / p), with phi the
        topic-term probability and p the corpus term proportion.

        Entries where phi or p is 0 are set to -inf so they rank last.
        """
        topic_term = np.asarray(topic_term, dtype=float)
        term_props = np.asarray(term_props, dtype=float)[None, :]
        defined = (topic_term > 0) & (term_props > 0)
        # log(0) and inf - inf only occur on entries masked out below.
        with np.errstate(divide="ignore", invalid="ignore"):
            log_phi = np.log(topic_term)
            log_lift = log_phi - np.log(term_props)
            scores = lambda_ * log_phi + (1.0 - lambda_) * log_lift
        return np.where(defined, scores, -np.inf)
# Example run: build an LDAViz with its default pipeline, load the hobbies
# corpus from yellowbrick, then fit and transform the corpus documents.
viz = LDAViz()
corpus = load_hobbies()
viz.fit_transform(corpus.data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment