Skip to content

Instantly share code, notes, and snippets.

@xLaszlo
Last active December 19, 2020 12:08
Show Gist options
  • Save xLaszlo/88e061c2247332d1227af2afaea8a2a0 to your computer and use it in GitHub Desktop.
Save xLaszlo/88e061c2247332d1227af2afaea8a2a0 to your computer and use it in GitHub Desktop.
Simple approximate string search
import numpy as np
import scipy.sparse as sps
from collections import Counter
from sklearn.feature_extraction.text import CountVectorizer
# Use sklearn's CountVectorizer with a custom tokenizer to turn each string into a count vector of its 2- and 3-character substrings.
# Store an L2-normalised version of those vectors in a sparse matrix.
class StringSearcher:
    """Approximate string search over a fixed list of names.

    Every name is lower-cased and broken into overlapping 2- and 3-character
    chunks, counted with ``CountVectorizer`` and stored as L2-normalised
    sparse rows. A query is vectorised the same way and names are ranked by
    the dot product with the query vector (highest first).
    """

    def __init__(self, names):
        """Build the search index.

        Args:
            names: iterable of strings to search over. It is consumed once;
                a lower-cased copy is kept in ``self.names``.
        """
        self.names = [name.lower() for name in names]
        self._get_common_words()
        self.vectorizer = CountVectorizer(tokenizer=self._tokenizer)
        # Fit on the stored lower-cased copy, not the raw argument: this
        # keeps matrix rows aligned with self.names and also works when
        # `names` is a one-shot iterator that has already been exhausted.
        self.vectorizedNames = self.vectorizer.fit_transform(self.names)
        # np.asarray works whether sum() hands back np.matrix or ndarray.
        norms = np.sqrt(
            np.asarray(self.vectorizedNames.power(2).sum(axis=1)).ravel()
        )
        # Rows without any token (names shorter than 2 characters) have a
        # zero norm; divide them by 1 instead so they stay all-zero.
        norms[norms == 0] = 1.0
        # calculate the constant part of cosine similarity one-off for sparse matrices
        self.vectorizedNames = sps.diags(1.0 / norms).dot(self.vectorizedNames)

    def _tokenizer(self, name):
        """Split ``name`` into 3- and 2-character chunks.

        Chunks of the name with common words removed are appended twice more,
        so non-common words end up weighted three fold. Hack the returned
        list for different behaviour.

        Args:
            name: string to tokenize.

        Returns:
            List of chunk strings; empty when ``name`` has fewer than 2
            characters.
        """
        def parts(s):
            # Slice the argument itself; slicing the enclosing `name` here
            # would make the extra weighting of the important part a no-op.
            trigrams = [s[i:i + 3] for i in range(len(s) - 2)]
            bigrams = [s[i:i + 2] for i in range(len(s) - 1)]
            return trigrams + bigrams

        importantPart = ' '.join(
            v for v in name.split(' ') if v not in self.commonWords
        )
        important = parts(importantPart)
        return parts(name) + important + important

    def _get_common_words(self):
        """Populate ``self.commonWords`` from ``self.names``.

        A word is common when it is longer than 2 characters and occurs more
        than 100 times across all names (names are split on single spaces).
        Hack the set comprehension for different behaviour.
        """
        words = Counter()
        for name in self.names:
            words.update(name.split(' '))
        self.commonWords = {k for k, v in words.items() if len(k) > 2 and v > 100}

    def find(self, query, N=10):
        """Return up to ``N`` stored names most similar to ``query``.

        Args:
            query: search string; it is lower-cased before vectorising.
            N: maximum number of results.

        Returns:
            List of lower-cased names, best match first. Names with equal
            scores keep their original relative order.
        """
        queryVector = self.vectorizer.transform([query.lower()]).T
        # toarray() instead of the `.A` shorthand, which newer SciPy no
        # longer provides on sparse matrices; ravel() keeps the result 1-D
        # even when the index holds a single name.
        scores = self.vectorizedNames.dot(queryVector).toarray().ravel()
        # Stable sort makes the order of tied scores deterministic.
        ranked = np.argsort(-scores, kind='stable')
        return [self.names[k] for k in ranked[:N]]
if __name__ == '__main__':
    # Small demo: index five words and print the ranked matches for one of them.
    searcher = StringSearcher(
        names=['Hello', 'World', 'How', 'Are', 'You'],
    )
    matches = searcher.find('Hello')
    print(matches)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment