Last active
May 19, 2023 10:04
-
-
Save cpcdoy/37b73944b605f4b51b442a7b1cc84a87 to your computer and use it in GitHub Desktop.
KB Trie example implementation using datrie
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import string | |
from itertools import chain, islice | |
import datrie # type: ignore | |
class KBTrie:
    """Knowledge base stored in a trie (``datrie.Trie``).

    A text is split into whitespace tokens, decomposed into consecutive
    n-grams, and every n-gram is looked up in the trie. Matching n-grams are
    wrapped in ``<beg_X>`` / ``<end_X>`` hint tokens, where ``X`` is the entity
    type stored for the matching entry.

    Keys are normalised with ``_preprocess_text`` both when they are stored
    and when they are looked up or removed, so with the default (lower-casing)
    preprocessing the knowledge base is case-insensitive.
    """

    def __init__(
        self,
        language_alphabet=string.printable,
        labels=None,
    ):
        """Create an empty knowledge base.

        Args:
            language_alphabet: Characters handed to ``datrie.Trie`` as the
                alphabet the trie can encode.
            labels: Entity type names used to build ``special_tokens``. A
                falsy value selects ``["org", "loc", "per", "date"]``.
        """
        if not labels:
            labels = ["org", "loc", "per", "date"]
        # All hint tokens that can appear in tagged text: every opening token
        # first, then every closing token. Not read by any method of this class.
        self.special_tokens = [f"<beg_{label}> " for label in labels] + [
            f" <end_{label}>" for label in labels
        ]
        self.language_alphabet = language_alphabet
        self.trie = datrie.Trie(self.language_alphabet)

    def print_internals(self) -> str:
        """Return (not print) a human-readable dump of the knowledge base."""
        # Fetch the entries once; they are needed for both the count and the listing.
        entries = self.trie.items()
        header = (
            "KBTrie:\n"
            f" - contains {len(entries)} entries\n"
            f" - encodes the following alphabet: '{self.language_alphabet}'\n\n"
            "Showing entries in the following format: {Entity}:{Type}\n\n"
        )
        return header + "\n".join(f"-- {k}: {v}" for k, v in entries)

    def __str__(self) -> str:
        """Return the same dump as ``print_internals``."""
        return self.print_internals()

    def __repr__(self) -> str:
        """Return the same dump as ``print_internals``."""
        return self.print_internals()

    def _lookup_key(self, text: str) -> str:
        """Normalise ``text`` the same way ``add_entry`` normalises stored keys.

        The entity type is not known when looking up or removing, so an empty
        entity is passed to ``_preprocess_text``.
        """
        return self._preprocess_text(text, "")

    def remove_entries(self, keys: list) -> None:
        """Remove the KB entries named in ``keys``.

        Each key is normalised first so that it addresses the entry exactly as
        ``add_entry`` stored it.
        """
        for k in keys:
            self.trie.pop(self._lookup_key(k))

    def _is_valid_term(self, text: str, entity: str) -> bool:
        """Extension hook: decide whether ``text`` may be stored for ``entity``.

        The default accepts everything; override to reject e.g. URLs or terms
        containing unwanted symbols.
        """
        return True

    def _preprocess_text(self, text: str, entity: str) -> str:
        """Extension hook: normalise ``text``, optionally based on ``entity``.

        The default lower-cases the text (example preprocessing).
        """
        return text.lower()

    def add_entry(self, text: str, entity: str) -> None:
        """Add one KB entry mapping ``text`` to the entity type ``entity``.

        The entity is stripped and lower-cased. The text is validated with
        ``_is_valid_term`` and stored under its preprocessed form. An entry
        that already exists is left untouched (first writer wins).
        """
        ent = entity.strip().lower()
        if not self._is_valid_term(text, ent):
            return
        # Store under the preprocessed key so the preprocessing step actually
        # takes effect and case variants collapse onto one entry.
        key = self._preprocess_text(text, ent)
        if key not in self.trie:
            self.trie[key] = ent

    def n_grams(self, seq: list[str], n: int = 1):
        """Return an iterator over the n-grams of ``seq``.

        Each n-gram is a tuple of ``n`` ``(index, token)`` pairs for
        consecutive tokens. Nothing is yielded when ``n`` is 0 or larger than
        ``len(seq)``.
        """
        # Stream i starts i tokens in; zipping the n streams yields the windows.
        shifted = (islice(enumerate(seq), i, None) for i in range(n))
        return zip(*shifted)

    def range_ngrams(
        self, list_tokens: list[str], ngram_range: tuple[int, int] = (1, 2)
    ):
        """Return an iterator over all n-grams for n in ``range(*ngram_range)``.

        The upper bound is exclusive, as with ``range``: ``(1, 4)`` produces
        1-, 2- and 3-grams.
        """
        return chain.from_iterable(
            self.n_grams(list_tokens, n) for n in range(*ngram_range)
        )

    def longest_prefix(self, ngram):
        """Look up one n-gram in the trie.

        Args:
            ngram: ``(start_index, text)`` pair.

        Returns:
            ``(ngram, score, found)``. ``found`` is the ``(key, label)`` item
            of the longest stored key that is a prefix of the normalised
            n-gram text, or ``("", None)`` if there is none. ``score`` is
            ``True`` only when that key equals the whole normalised text.
        """
        query = self._lookup_key(ngram[1])
        found = self.trie.longest_prefix_item(query, default=("", None))
        # Exact match keeps the example simple. Fuzzy matching (e.g. Levenshtein
        # distance, or a Levenshtein automaton built with the trie) could
        # replace this to make the KB robust to typos.
        score = query == found[0]
        return (ngram, score, found)

    def find_matching_ngrams(
        self, tx_filtered: list[str], ngram_range: tuple[int, int]
    ) -> list:
        """Return the n-grams of ``tx_filtered`` that match a KB entry.

        Args:
            tx_filtered: Tokens of the text (use your own tokenization).
            ngram_range: ``(start, stop)`` n-gram sizes, ``stop`` exclusive.

        Returns:
            The ``longest_prefix`` results with a truthy score, sorted by start
            index, then score, then longest n-gram text first.
        """
        candidates = (
            (grams[0][0], " ".join(token for _, token in grams))
            for grams in self.range_ngrams(tx_filtered, ngram_range=ngram_range)
        )
        matches = (m for m in map(self.longest_prefix, candidates) if m[1])
        # Two n-grams starting at the same index always differ in length, so
        # this key is a total order and the input order does not matter.
        return sorted(matches, key=lambda m: (m[0][0], m[1], -len(m[0][1])))

    def inject_hints_to_text(
        self,
        tx: str,
        ngram_range: tuple[int, int] = (1, 4),
    ) -> str:
        """Return ``tx`` with every KB match wrapped in hint tokens.

        The text is split on whitespace and re-joined with single spaces. A
        match is wrapped as ``<beg_X> ... <end_X>``, where ``X`` is the label
        stored for it (org, loc, per, date, ...). When several matches start on
        the same token, the longest one is used; tokens already inside a
        wrapped match do not start another one.
        """
        tokens = tx.split()
        covered = set()
        matches = self.find_matching_ngrams(tokens, ngram_range)
        for (start, query), score, (_, label) in matches:
            if start in covered or not score:
                continue
            n_tokens = len(query.split())
            end = start + n_tokens - 1
            covered.update(range(start, start + n_tokens))
            # Example: "<beg_org> uber eats <end_org>"
            tokens[start] = f"<beg_{label}> {tokens[start]}"
            tokens[end] = f"{tokens[end]} <end_{label}>"
        return " ".join(tokens)
def main():
    """Demo: fill a small knowledge base and tag two sample sentences."""
    kb = KBTrie()

    # An empty KB prints its header only:
    #   KBTrie:
    #    - contains 0 entries
    #    - encodes the following alphabet: '<string.printable>'
    #
    #   Showing entries in the following format: {Entity}:{Type}
    print(kb)

    sample_entries = (
        ("paypal", "org"),
        ("google", "org"),
        ("new york", "loc"),
        ("new york us", "loc"),
        ("new york usa", "loc"),
    )
    for surface_form, entity_type in sample_entries:
        kb.add_entry(surface_form, entity_type)

    # Now the same header reports 5 entries, followed by the listing:
    #   -- google: org
    #   -- new york: loc
    #   -- new york us: loc
    #   -- new york usa: loc
    #   -- paypal: org
    print(kb)

    # Expected:
    #   <beg_org> paypal <end_org> * <beg_org> google <end_org> <beg_loc> new york us <end_loc>
    # Timing noted by the author: 44.8 µs ± 535 ns per loop
    # (mean ± std. dev. of 7 runs, 10,000 loops each)
    short_text = "paypal * google new york us"
    print(kb.inject_hints_to_text(short_text))

    # Expected:
    #   padassdaypal paypal* <beg_org> paypal <end_org> <beg_org> google <end_org> <beg_org> google <end_org> <beg_org> paypal <end_org> * <beg_org> google <end_org> <beg_loc> new york us <end_loc> us us "us
    # Timing noted by the author: 92.8 µs ± 1.25 µs per loop
    # (mean ± std. dev. of 7 runs, 10,000 loops each)
    noisy_text = (
        'padassdaypal paypal* paypal google google paypal * google new york us us us "us'
    )
    print(kb.inject_hints_to_text(noisy_text))


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment