Created
March 6, 2020 01:37
-
-
Save bbengfort/da0df7b0a9864ca65c2dbd6701b718d2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np

from sklearn.decomposition import LatentDirichletAllocation
from sklearn.exceptions import NotFittedError
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import Pipeline

from yellowbrick.datasets import load_hobbies
class LDAViz(object):
    """
    Visual diagnostics for a Latent Dirichlet Allocation topic model.

    Wraps a vectorizer + LDA :class:`~sklearn.pipeline.Pipeline` and, on
    transform, computes the document/topic and topic/term distributions
    together with the per-term saliency (Chuang et al. 2012) and per-topic
    term relevance (Sievert & Shirley 2014) scores used by topic model
    visualizations such as pyLDAvis.

    Parameters
    ----------
    transformer : sklearn.pipeline.Pipeline, optional
        A user specified vectorizer and LDA model in a pipeline.

    n_components : int, optional (default=10)
        Number of topics. Not used if transformer is specified.

    lambda_ : float, optional (default=0.6)
        Weight balancing topic-specific probability against lift in the
        relevance computation (Sievert & Shirley 2014 recommend 0.6).
    """

    def __init__(self, transformer=None, n_components=10, lambda_=0.6):
        # TODO: accept user parameters (e.g. vectorizer/LDA kwargs)
        self.n_components = n_components
        self.lambda_ = lambda_
        self.transformer = transformer

    @property
    def transformer(self):
        return self._transformer

    @transformer.setter
    def transformer(self, transformer):
        # Build the default CountVectorizer -> LDA pipeline lazily when the
        # user does not supply one of their own.
        if transformer is None:
            transformer = Pipeline([
                ("vec", CountVectorizer()),
                ("lda", LatentDirichletAllocation(n_components=self.n_components)),
            ])

        if not isinstance(transformer, Pipeline):
            raise ValueError(
                "transformer must be a pipeline with a vectorizer and LDA"
            )

        self._transformer = transformer

    def fit(self, X):
        """
        Fit the vectorizer and the LDA model on the corpus X.

        Returns
        -------
        self : LDAViz
            The fitted visualizer, for method chaining.
        """
        self.transformer.fit(X)
        return self

    def transform(self, X):
        """
        Compute topic distributions and visualization statistics for X.

        Sets the learned attributes ``doc_lengths_``, ``term_freqs_``,
        ``doc_topic_dists_``, ``topic_term_dists_``, ``saliencies_`` and
        ``relevencies_`` and returns the document/topic distribution.

        Parameters
        ----------
        X : iterable of str
            The corpus of documents to transform.

        Returns
        -------
        doc_topic_dists : ndarray of shape (n_docs, n_topics)
            Row-normalized topic distribution for each document.

        Raises
        ------
        NotFittedError
            If the pipeline has not been fitted yet.
        """
        self._check_fitted()

        # Document-term matrix from the already-fitted vectorizer; .getA1()
        # flattens the sparse-matrix marginals into 1-D arrays.
        dtm = self.transformer.steps[0][1].transform(X)
        self.doc_lengths_ = dtm.sum(axis=1).getA1()
        self.term_freqs_ = dtm.sum(axis=0).getA1()

        # Row-normalize so each row is a probability distribution.
        self.doc_topic_dists_ = self._row_norm(self.transformer.transform(X))
        self.topic_term_dists_ = self._row_norm(
            self.transformer.steps[-1][1].components_
        )

        # Overall topic proportions p(t), weighted by document length.
        topic_freq = self.doc_topic_dists_.T.dot(self.doc_lengths_)
        topic_proportion = topic_freq / topic_freq.sum()

        # p(topic | term): distribution over topics for each term.
        term_topic_freq = self.topic_term_dists_ * topic_freq[:, None]
        p_topic_given_term = self._row_norm(term_topic_freq.T)

        # Saliency(w) = freq(w) * KL(p(t|w) || p(t))  (Chuang et al. 2012).
        # NOTE: sklearn's LDA components_ are strictly positive (Dirichlet
        # prior), so all logarithms below are well defined.
        distinctiveness = np.sum(
            p_topic_given_term * np.log(p_topic_given_term / topic_proportion),
            axis=1,
        )
        self.saliencies_ = self.term_freqs_ * distinctiveness

        # Relevance(w, t) = lambda*log p(w|t) + (1-lambda)*log(p(w|t)/p(w))
        # (Sievert & Shirley 2014). Original attribute spelling kept for
        # backward compatibility.
        term_proportion = term_topic_freq.sum(axis=0) / term_topic_freq.sum()
        log_ttd = np.log(self.topic_term_dists_)
        log_lift = np.log(self.topic_term_dists_ / term_proportion)
        self.relevencies_ = self.lambda_ * log_ttd + (1 - self.lambda_) * log_lift

        return self.doc_topic_dists_

    def fit_transform(self, X):
        """
        Fit the pipeline on X, then transform X. See fit() and transform().
        """
        self.fit(X)
        return self.transform(X)

    @property
    def vocab_(self):
        """
        The fitted vectorizer's vocabulary, one entry per dtm column.
        """
        self._check_fitted()
        vec = self.transformer.steps[0][1]
        # get_feature_names was removed in scikit-learn 1.2 in favor of
        # get_feature_names_out; support both for compatibility.
        if hasattr(vec, "get_feature_names_out"):
            return vec.get_feature_names_out()
        return vec.get_feature_names()

    def _check_fitted(self):
        # The final LDA estimator exposes components_ only after fitting;
        # steps[-1] is a (name, estimator) tuple so index into it.
        if not hasattr(self.transformer.steps[-1][1], "components_"):
            raise NotFittedError("must fit the LDA pipeline before transforming")

    @staticmethod
    def _row_norm(dists):
        # Normalize each row of a 2-D array so it sums to 1.
        return dists / dists.sum(axis=1)[:, None]
# Demo: fit the default CountVectorizer + LDA pipeline on the Yellowbrick
# hobbies corpus, then compute the topic diagnostics for the same documents.
corpus = load_hobbies()
viz = LDAViz()
viz.fit(corpus.data)
viz.transform(corpus.data)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment