Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save sanchezg/82b72105491c2573cd0df9b50f83b3e3 to your computer and use it in GitHub Desktop.
Save sanchezg/82b72105491c2573cd0df9b50f83b3e3 to your computer and use it in GitHub Desktop.
from sklearn.feature_extraction.text import CountVectorizer
class ProjectionCountVectorizer(CountVectorizer):
    """CountVectorizer that first extracts a nested field from each document.

    Every raw document is walked along ``projection_path`` (see
    :meth:`do_projection`); the value found there is what the regular
    ``CountVectorizer`` preprocessor receives.

    Parameters
    ----------
    projection_path : str or sequence of str
        Steps separated by ``'/'``, e.g. ``'user/posts/0/title'``. An already
        split sequence of steps is accepted as well and stored as given.
    *args, **kwargs
        Forwarded unchanged to ``CountVectorizer.__init__``.

    Attributes
    ----------
    projection_path : list of str
        The individual steps. For string input this is
        ``projection_path.split('/')``.
    """

    def __init__(self, projection_path, *args, **kwargs):
        # Only split strings: a clone is rebuilt from the already-split value
        # returned by get_params(), which has no ``.split``.
        if isinstance(projection_path, str):
            projection_path = projection_path.split('/')
        self.projection_path = projection_path
        super().__init__(*args, **kwargs)

    @classmethod
    def _get_param_names(cls):
        """Return the sorted parameter names exposed through ``get_params``.

        scikit-learn normally discovers parameters by inspecting ``__init__``
        and rejects signatures that contain ``*args``. Listing the parent's
        parameters plus ``projection_path`` explicitly keeps ``get_params``
        (and what is built on it, such as ``set_params`` and ``clone``)
        working without changing the constructor signature.
        """
        parent_names = set(CountVectorizer._get_param_names())
        return sorted(parent_names | {'projection_path'})

    def build_preprocessor(self):
        """Return a callable that projects a document, then preprocesses it.

        The returned function applies :meth:`do_projection` to the raw
        document and passes the result to the preprocessor built by
        ``CountVectorizer.build_preprocessor``.
        """
        preprocess = super().build_preprocessor()

        def projection_and_preprocess(doc):
            return preprocess(self.do_projection(doc))

        return projection_and_preprocess

    def do_projection(self, doc):
        """Follow ``self.projection_path`` into ``doc`` and return the result.

        At each step the current value is handled according to its type:

        * ``dict`` -- the step is used as a key.
        * ``list`` / ``tuple`` -- a step that parses as an integer is used as
          an index (negative indices count from the end); any other step is
          read as an attribute, which is only meaningful for namedtuples.

        Empty steps (from an empty path, a leading or trailing ``'/'``, or
        ``'//'``) are skipped.

        Raises
        ------
        ValueError
            If a step has to be applied to a value that is neither a dict,
            a list nor a tuple.
        KeyError, IndexError, AttributeError
            Propagated unchanged when a key, index or attribute is missing.
        """
        steps = self.projection_path
        if isinstance(steps, str):
            # set_params(projection_path='a/b') stores the raw string.
            steps = steps.split('/')

        for step in steps:
            if step == '':
                continue
            if isinstance(doc, dict):
                doc = doc[step]
            elif isinstance(doc, (tuple, list)):
                try:
                    index = int(step)
                except ValueError:
                    # Not an index: only valid for namedtuples.
                    doc = getattr(doc, step)
                else:
                    doc = doc[index]
            else:
                raise ValueError('cant apply step %s' % step)
        return doc
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment