Last active
December 19, 2020 12:08
-
-
Save xLaszlo/88e061c2247332d1227af2afaea8a2a0 to your computer and use it in GitHub Desktop.
Simple approximate string search
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import scipy.sparse as sps | |
from collections import Counter | |
from sklearn.feature_extraction.text import CountVectorizer | |
# Use sklearn's CountVectorizer with a custom tokenizer to turn each string into a count vector of its 2- and 3-character substrings.
# Store an L2-normalised version of those vectors in a sparse matrix.
class StringSearcher:
    """Approximate string search over a fixed list of names.

    Each name is split into overlapping 2- and 3-character chunks, counted
    with ``CountVectorizer`` and stored as an L2-normalised row of a sparse
    matrix.  A query is vectorised the same way and the names are ranked by
    the dot product of their row with the query vector.
    """

    def __init__(self, names):
        """Index ``names`` for searching.

        Args:
            names: Iterable of strings to search over.  They are lower-cased
                and kept in ``self.names``; ``find`` returns these
                lower-cased forms.
        """
        self.names = [name.lower() for name in names]
        self._get_common_words()
        self.vectorizer = CountVectorizer(tokenizer=self._tokenizer)
        # Fit on the stored list, not on the ``names`` argument: the argument
        # may be a one-shot iterable that the comprehension above consumed.
        self.vectorizedNames = self.vectorizer.fit_transform(self.names)
        # Row-wise Euclidean norms; np.asarray copes with both the np.matrix
        # and the plain ndarray that sparse ``sum`` can return.
        norms = np.sqrt(
            np.asarray(self.vectorizedNames.power(2).sum(axis=1)).ravel()
        )
        # Names that produced no chunks have a zero norm; avoid dividing by 0.
        norms[norms == 0] = 1.0
        # Calculate the constant part of cosine similarity one-off by scaling
        # every row to unit length.
        self.vectorizedNames = sps.diags(1.0 / norms).dot(self.vectorizedNames)

    def _tokenizer(self, name):
        """Split ``name`` into 2- and 3-character chunks.

        Chunks coming from words that are not in ``self.commonWords`` are
        emitted three times in total (once from the full name, twice from
        the name with common words removed), so they weigh three-fold.
        Hack the returned list for different behaviour.

        Args:
            name: The string to tokenize.

        Returns:
            A list of chunk strings, possibly containing duplicates.
        """
        def parts(s):
            # Chunk the string passed in, not the enclosing ``name``.
            trigrams = [s[i:i + 3] for i in range(len(s) - 2)]
            bigrams = [s[i:i + 2] for i in range(len(s) - 1)]
            return trigrams + bigrams

        important_part = ' '.join(
            word for word in name.split(' ') if word not in self.commonWords
        )
        return parts(name) + parts(important_part) + parts(important_part)

    def _get_common_words(self):
        """Collect frequent words of ``self.names`` into ``self.commonWords``.

        A word is common when it is longer than 2 characters and occurs more
        than 100 times across all names (names are split on single spaces).
        Hack the set comprehension for different behaviour.
        """
        words = Counter()
        for name in self.names:
            words.update(name.split(' '))
        self.commonWords = {
            word for word, count in words.items()
            if len(word) > 2 and count > 100
        }

    def find(self, query, N=10):
        """Return the ``N`` indexed names most similar to ``query``.

        Args:
            query: The string to look for; it is lower-cased before matching.
            N: Maximum number of names to return.

        Returns:
            A list of at most ``N`` lower-cased names, best match first.
        """
        queryVector = self.vectorizer.transform([query.lower()]).T
        # The query vector is left un-normalised: that scales every score by
        # the same factor and does not change the ranking.
        # ``toarray().ravel()`` always gives a 1-d array, even for one name.
        scores = self.vectorizedNames.dot(queryVector).toarray().ravel()
        # Stable sort keeps tied names in their original order.
        ranked = np.argsort(-scores, kind='stable')
        return [self.names[k] for k in ranked[:N]]
if __name__ == '__main__':
    # Small demonstration: index five words and print the ranked matches
    # for a query that is identical to one of them.
    demo_names = ['Hello', 'World', 'How', 'Are', 'You']
    searcher = StringSearcher(names=demo_names)
    matches = searcher.find('Hello')
    print(matches)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment