Skip to content

Instantly share code, notes, and snippets.

@cpcdoy
Last active May 19, 2023 10:04
Show Gist options
  • Save cpcdoy/37b73944b605f4b51b442a7b1cc84a87 to your computer and use it in GitHub Desktop.
Save cpcdoy/37b73944b605f4b51b442a7b1cc84a87 to your computer and use it in GitHub Desktop.
KB Trie example implementation using datrie
import string
from itertools import chain, islice
import datrie # type: ignore
class KBTrie:
    """
    Knowledge Base with a Trie structure.

    Does consecutive n-gram decomposition of a text and then uses the Trie to look-up.

    Keys are normalised with ``_preprocess_text`` when they are stored, looked up
    and removed, so all three operations agree on the key. The entity label is
    not known at look-up/removal time, so an empty string is passed for it there;
    an override of ``_preprocess_text`` must produce the same key in that case.
    """

    def __init__(
        self,
        language_alphabet=string.printable,
        labels=None,
    ):
        """
        Init KBTrie

        Args:
            language_alphabet: characters handed to ``datrie.Trie`` as its alphabet.
            labels: entity labels used to build ``special_tokens``; when empty
                or ``None`` the labels org, loc, per and date are used.
        """
        if not labels:
            labels = ["org", "loc", "per", "date"]
        # One "<beg_X> " and one " <end_X>" marker per label. No method of this
        # class reads the list; it is exposed for callers.
        self.special_tokens = [f"<beg_{label}> " for label in labels] + [
            f" <end_{label}>" for label in labels
        ]
        self.language_alphabet = language_alphabet
        self.trie = datrie.Trie(self.language_alphabet)

    def print_internals(self) -> str:
        """
        Return (not print) a human-readable dump of the KBTrie internals:
        entry count, alphabet and every ``key: label`` pair.
        """
        entries = self.trie.items()  # fetched once, reused for count and listing
        header = (
            "KBTrie:\n"
            f" - contains {len(entries)} entries\n"
            f" - encodes the following alphabet: '{self.language_alphabet}'\n\n"
            "Showing entries in the following format: {Entity}:{Type}\n\n"
        )
        return header + "\n".join(f"-- {k}: {v}" for k, v in entries)

    def __str__(self):
        """
        Same text as ``print_internals``
        """
        return self.print_internals()

    def __repr__(self):
        """
        Same text as ``print_internals``
        """
        return self.print_internals()

    def remove_entries(self, keys: list):
        """
        Remove KB entries from a list of keys.

        Each key is normalised the same way ``add_entry`` normalises it before
        it is popped from the trie. A missing key propagates whatever the
        trie's ``pop`` raises.
        """
        for k in keys:
            self.trie.pop(self._preprocess_text(k, ""))

    def _is_valid_term(self, text: str, entity: str) -> bool:
        """Add any conditions, like should the entity be a URL, or contain weird symbols, etc"""
        return True

    def _preprocess_text(self, text: str, entity: str) -> str:
        """
        Apply any kind of processing even based on the entity type.

        ``entity`` is the normalised label when called from ``add_entry`` and
        an empty string when called for look-up or removal.
        """
        return text.lower()  # Example preprocessing

    def add_entry(self, text: str, entity: str):
        """
        Add one KB entry.

        The label is stripped and lower-cased; the key is ``text`` after
        ``_preprocess_text``. Invalid terms and keys already present are
        ignored (the first label stored for a key wins).
        """
        ent = entity.strip().lower()
        if not self._is_valid_term(text, ent):
            return
        # Store the *preprocessed* key; storing the raw text would leave the
        # preprocessing step without any effect.
        key = self._preprocess_text(text, ent)
        if key not in self.trie:
            self.trie[key] = ent

    def n_grams(self, seq: list[str], n: int = 1):
        """
        Returns an iterator over the n-grams given a list of tokens.

        Each n-gram is a tuple of ``(index, token)`` pairs. Nothing is yielded
        when ``seq`` has fewer than ``n`` tokens or when ``n`` is not positive.
        """
        indexed = list(enumerate(seq))
        # The i-th stream starts at token i; zipping the streams yields the windows.
        return zip(*(islice(indexed, start, None) for start in range(n)))

    def range_ngrams(
        self, list_tokens: list[str], ngram_range: tuple[int, int] = (1, 2)
    ):
        """
        Returns an iterator over all n-grams for n in range(*ngram_range) given a list_tokens.

        The upper bound is exclusive, so the default ``(1, 2)`` yields unigrams only.
        """
        return chain.from_iterable(
            self.n_grams(list_tokens, n) for n in range(*ngram_range)
        )

    def longest_prefix(self, ngram):
        """
        Get the longest prefix matching the ngram from the Trie.

        Args:
            ngram: ``(start_index, text)`` pair.

        Returns:
            ``(ngram, score, found)`` where ``found`` is the ``(key, label)``
            pair returned by the trie (``("", None)`` when nothing matches) and
            ``score`` is True only when the whole normalised n-gram equals that key.
        """
        # Normalise the query exactly like stored keys are normalised.
        query = self._preprocess_text(ngram[1], "")
        # Find the longest matching prefix entity to the ngram
        found = self.trie.longest_prefix_item(query, default=("", None))
        # The score could use levenshtein distance or anything else, here we use exact match for the sake of a simple example
        # It is possible to implement fuzzy matching so the KB is robust to typos and more
        # A more robust approach would be to use a Levenshtein automaton when constructing the trie
        score = query == found[0]
        return (ngram, score, found)

    def find_matching_ngrams(
        self, tx_filtered: list[str], ngram_range: tuple[int, int]
    ) -> list:
        """
        Find n-grams that match items recovered from the KB Trie.

        Args:
            tx_filtered: the tokens of the text (use your own tokenization upstream).
            ngram_range: passed to ``range_ngrams``.

        Returns:
            The ``longest_prefix`` results whose score is truthy, ordered by
            start index and, for one start index, longest text first.
        """
        tx_ngrams_n = [
            (ngrams[0][0], " ".join(token for _, token in ngrams))
            for ngrams in self.range_ngrams(tx_filtered, ngram_range=ngram_range)
        ][::-1]
        found_elems = sorted(
            filter(
                lambda r: r[1],
                map(self.longest_prefix, tx_ngrams_n),
            ),
            # Sort by:
            # - start index of the n-gram in the token list
            # - match score (in this example it's only 1 or 0 since exact matching is used)
            # - len in letters of the n-gram text, longest first
            key=lambda x: (x[0][0], x[1], -len(x[0][1])),
        )
        return found_elems

    def inject_hints_to_text(
        self,
        tx: str,
        ngram_range: tuple[int, int] = (1, 4),
    ) -> str:
        """
        This adds the final hints to a piece of text.

        The text is split on whitespace, every matching n-gram is wrapped in
        ``<beg_X> ... <end_X>`` (X being the label stored in the KB) and the
        tokens are joined back with single spaces. The original casing of the
        tokens is kept.
        """
        tx_filtered = tx.split()
        found_elems = self.find_matching_ngrams(tx_filtered, ngram_range)
        items_already_seen = set()
        for (index_query, item_query), score, (_, label_found) in found_elems:
            # Matches arrive by start index with the longest first, so a start
            # index that is already covered belongs to an earlier, longer match.
            if index_query in items_already_seen or not score:
                continue
            len_query = len(item_query.split())
            end_index = index_query + len_query - 1
            items_already_seen.update(range(index_query, end_index + 1))
            # Example:
            # [BEG_TOKEN] uber eats [END_TOKEN]
            # Note: you can use any token you want here
            # I'm using: <beg_X> and <end_X> where X will be org, loc, per, date, etc
            tx_filtered[index_query] = f"<beg_{label_found}> {tx_filtered[index_query]}"
            tx_filtered[end_index] = f"{tx_filtered[end_index]} <end_{label_found}>"
        return " ".join(tx_filtered)
def main():
    """Demo: build a small KB, print it, then inject hints into two sample texts."""
    kb = KBTrie()

    # Output:
    # KBTrie:
    # - contains 0 entries
    # - encodes the following alphabet: '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~
    # '
    # Showing entries in the following format: {Entity}:{Type}
    print(kb)

    seed_entries = (
        ("paypal", "org"),
        ("google", "org"),
        ("new york", "loc"),
        ("new york us", "loc"),
        ("new york usa", "loc"),
    )
    for surface_form, label in seed_entries:
        kb.add_entry(surface_form, label)

    # Output:
    # KBTrie:
    # - contains 5 entries
    # - encodes the following alphabet: '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~
    # '
    # Showing entries in the following format: {Entity}:{Type}
    # -- google: org
    # -- new york: loc
    # -- new york us: loc
    # -- new york usa: loc
    # -- paypal: org
    print(kb)

    sample_texts = (
        # Output:
        # <beg_org> paypal <end_org> * <beg_org> google <end_org> <beg_loc> new york us <end_loc>
        # Timing: 44.8 µs ± 535 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
        "paypal * google new york us",
        # Output:
        # padassdaypal paypal* <beg_org> paypal <end_org> <beg_org> google <end_org> <beg_org> google <end_org> <beg_org> paypal <end_org> * <beg_org> google <end_org> <beg_loc> new york us <end_loc> us us "us
        # Timing: 92.8 µs ± 1.25 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
        'padassdaypal paypal* paypal google google paypal * google new york us us us "us',
    )
    for sample in sample_texts:
        print(kb.inject_hints_to_text(sample))
# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment