Skip to content

Instantly share code, notes, and snippets.

@kmike
Last active March 16, 2022 09:54
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save kmike/7814472 to your computer and use it in GitHub Desktop.
Save kmike/7814472 to your computer and use it in GitHub Desktop.
import marisa_trie
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
# Hack: store the fitted vocabulary in a MARISA trie instead of a dict.
class _MarisaVocabularyMixin(object):
    """Replace a fitted vectorizer's ``vocabulary_`` dict with a MARISA trie.

    List this mixin *before* the scikit-learn vectorizer in the class bases
    so that its ``fit_transform`` wraps the vectorizer's own.

    Fitting makes two passes over the documents. The first pass learns the
    vocabulary as usual, and ``_freeze_vocabulary`` then converts it into a
    ``marisa_trie.Trie``. The second pass re-vectorizes with the trie in
    place, presumably so that column indices agree with the trie's key ids,
    which need not match the indices of the original dict.
    """

    def fit_transform(self, raw_documents, y=None):
        """Fit, freeze the vocabulary into a trie, then vectorize again.

        Parameters
        ----------
        raw_documents : iterable
            Documents to learn from. They are consumed twice, so one-shot
            iterators (e.g. generators) are first materialized into a list.
        y : object, optional
            Not passed to the first pass; forwarded to the second.

        Returns
        -------
        The result of the wrapped vectorizer's ``fit_transform`` on the
        second pass.
        """
        raw_documents = self._as_reiterable(raw_documents)
        parent = super(_MarisaVocabularyMixin, self)
        parent.fit_transform(raw_documents)
        self._freeze_vocabulary()
        # NOTE(review): this relies on the wrapped vectorizer honouring the
        # ``vocabulary_`` / ``fixed_vocabulary_`` set on the instance during
        # this second pass. Newer scikit-learn versions may recompute them
        # from the ``vocabulary`` constructor parameter; confirm against the
        # installed version.
        return parent.fit_transform(raw_documents, y)

    @staticmethod
    def _as_reiterable(raw_documents):
        """Return *raw_documents* in a form that can safely be iterated twice."""
        try:
            # An iterator is its own iterator. Iterating it a second time
            # yields nothing, so the second fitting pass would silently see an
            # empty corpus.
            is_iterator = iter(raw_documents) is raw_documents
        except TypeError:
            # Not iterable at all: let the vectorizer report the error itself.
            return raw_documents
        # Re-iterable containers (lists, arrays, ...) are passed through
        # untouched to avoid copying a potentially large corpus.
        return list(raw_documents) if is_iterator else raw_documents

    def _freeze_vocabulary(self):
        """Convert the learned ``vocabulary_`` dict into a ``marisa_trie.Trie``.

        Does nothing when ``fixed_vocabulary_`` is already true. Otherwise it
        builds the trie from the vocabulary's keys and sets
        ``fixed_vocabulary_`` to True. It also drops ``stop_words_``, which
        scikit-learn documents as introspection-only and potentially large.
        """
        if self.fixed_vocabulary_:
            return
        self.vocabulary_ = marisa_trie.Trie(self.vocabulary_.keys())
        self.fixed_vocabulary_ = True
        # Deleting a missing attribute raises AttributeError. The attribute
        # may be absent depending on the vectorizer/version, and there is
        # nothing to free in that case.
        if hasattr(self, 'stop_words_'):
            del self.stop_words_
class MarisaCountVectorizer(_MarisaVocabularyMixin, CountVectorizer):
    """``CountVectorizer`` whose fitted ``vocabulary_`` is a ``marisa_trie.Trie``.

    All behaviour comes from ``_MarisaVocabularyMixin`` and ``CountVectorizer``.
    The mixin is listed first so that its ``fit_transform`` wraps the
    vectorizer's own.
    """
class MarisaTfidfVectorizer(_MarisaVocabularyMixin, TfidfVectorizer):
    """``TfidfVectorizer`` whose fitted ``vocabulary_`` is a ``marisa_trie.Trie``."""

    def fit(self, raw_documents, y=None):
        """Fit, freeze the vocabulary into a trie, then fit again.

        Overridden in addition to the mixin's ``fit_transform``, presumably
        because ``TfidfVectorizer.fit`` does not route through
        ``fit_transform``. TODO: confirm against the installed scikit-learn.

        Parameters
        ----------
        raw_documents : iterable
            Documents to learn from. They are consumed twice, so one-shot
            iterators (e.g. generators) are first materialized into a list.
        y : object, optional
            Not passed to the first pass; forwarded to the second.

        Returns
        -------
        Whatever ``TfidfVectorizer.fit`` returns for the second pass.
        """
        try:
            # A generator would be exhausted by the first pass, leaving the
            # second fit (which determines the final vocabulary and idf) an
            # empty corpus.
            if iter(raw_documents) is raw_documents:
                raw_documents = list(raw_documents)
        except TypeError:
            pass  # not iterable: let TfidfVectorizer raise its own error
        parent = super(MarisaTfidfVectorizer, self)
        parent.fit(raw_documents)
        self._freeze_vocabulary()
        return parent.fit(raw_documents, y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment