# lexical_set_vectorizer.py

import re
from collections import OrderedDict

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class LexicalSetVectorizer(BaseEstimator, TransformerMixin):
    """Count occurrences of named word sets (or regexes) in each document."""

    def __init__(self, word_sets=None, normalize=False, lower=False,
                 token_pattern=r'(?u)\b\w\w+\b'):
        self.word_sets = word_sets
        self.normalize = normalize
        self.lower = lower
        self.token_pattern = token_pattern

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        word_sets = self.word_sets
        # Sort by feature name so the column order is deterministic.
        word_sets = OrderedDict(sorted(word_sets.items())) if word_sets else {}
        self.feature_names_ = list(word_sets.keys())
        token_pattern = re.compile(self.token_pattern)
        counts = np.zeros((len(X), len(word_sets)), dtype=np.float64)
        for row, doc in enumerate(X):
            doc = doc.lower() if self.lower else doc
            tokenized_doc = token_pattern.findall(doc)
            for col, word_set in enumerate(word_sets.values()):
                if hasattr(word_set, "match"):  # word_set is a compiled regex
                    count = len(word_set.findall(doc))
                elif hasattr(word_set, "upper"):  # word_set is a single string
                    count = sum(word == word_set for word in tokenized_doc)
                else:  # word_set is an iterable of words
                    count = sum(word == ref
                                for word in tokenized_doc
                                for ref in word_set)
                counts[row, col] = count
            if self.normalize and tokenized_doc:
                # Skip empty documents to avoid dividing by zero.
                counts[row, :] /= len(tokenized_doc)
        return counts

    def get_feature_names(self):
        if not hasattr(self, "feature_names_"):
            self.feature_names_ = sorted(self.word_sets.keys())
        return self.feature_names_
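A minimal usage sketch, not part of the gist itself; the feature names and input documents here are made up for illustration. It shows the three kinds of values word_sets accepts (a single word, a word list, a compiled regex) and that columns come back in sorted feature-name order:

import re

from lexical_set_vectorizer import LexicalSetVectorizer

# Hypothetical features: a word list and a compiled regex.
vect = LexicalSetVectorizer(
    word_sets={"negation": ["not", "never", "no"],
               "ellipsis": re.compile(r"\.\.\.")},
    lower=True)

X = vect.fit_transform(["No, never...", "Maybe so."])
print(vect.get_feature_names())  # ['ellipsis', 'negation'] (sorted)
print(X)                         # [[1. 2.]
                                 #  [0. 0.]]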
# test_lexical_set_vectorizer.py

import re

from numpy.testing import assert_array_equal

from lexical_set_vectorizer import LexicalSetVectorizer


def test_simple_example():
    vect = LexicalSetVectorizer(dict(yes="yes",
                                     no=["no", "nope"],
                                     conj=re.compile(r"\b(but|and)")),
                                normalize=False)
    X = vect.fit_transform(["I would say yes but... yess",
                            "yes and no",
                            "nope nope and... nope"])
    features = vect.get_feature_names()
    assert_array_equal(X[:, features.index("yes")], [1, 1, 0])
    assert_array_equal(X[:, features.index("no")], [0, 1, 3])
    assert_array_equal(X[:, features.index("conj")], [1, 1, 1])


def test_normalize():
    vect = LexicalSetVectorizer(dict(yes="yes", no="no"), normalize=True)
    X = vect.fit_transform(["yes yes", "no no no no no no no"])
    assert_array_equal(X.sum(axis=0), [1, 1])
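Since the class follows the scikit-learn transformer API, it can sit next to a standard bag-of-words model inside a FeatureUnion. A sketch of that, again with made-up word sets and documents; it assumes scikit-learn's FeatureUnion stacks the sparse CountVectorizer block with the dense lexical block, which it does via scipy's sparse hstack:

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import FeatureUnion

from lexical_set_vectorizer import LexicalSetVectorizer

# Hypothetical word set; any dict of name -> string/list/regex works.
union = FeatureUnion([
    ("bow", CountVectorizer()),
    ("lexical", LexicalSetVectorizer({"hedges": ["maybe", "perhaps"]})),
])
X = union.fit_transform(["maybe yes", "perhaps no", "definitely"])
# X appends the single 'hedges' count column to the bag-of-words columns.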