KLabs Trie tokenizer
# Copyright 2023 KLabs Co., Ltd.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
# Filename: klabs_nlp/trie.py
from __future__ import annotations
from typing import Dict, Iterable, List, NamedTuple, Optional, Tuple


class Tag(NamedTuple):
    # beginning position of the tag
    begin: int
    # ending position of the tag
    end: int
    # matched token text
    token: str


class _TrieNode:
    def __init__(
        self, is_valid: bool = False, branch: Optional[Dict[str, _TrieNode]] = None
    ):
        self.is_valid = is_valid
        if branch is None:
            branch = {}
        self._branch: Dict[str, _TrieNode] = branch

    def get_branch(self, char: str) -> _TrieNode:
        # return the child node for `char`, creating it if it does not exist yet
        if char not in self._branch:
            new_node = _TrieNode()
            self._branch[char] = new_node
            return new_node
        return self._branch[char]

    def add(self, chars: str) -> None:
        # insert `chars` into the trie, marking the final node as a valid token end
        if chars == "":
            self.is_valid = True
        else:
            self.get_branch(chars[0]).add(chars[1:])

    def to_string(self, indent: int = 0) -> str:
        # render the trie as an indented tree; valid token ends are marked with "*"
        res = ""
        ind = " " * indent
        for key, node in self._branch.items():
            res += f"{ind}{key}:"
            if node.is_valid:
                res += "*\n"
            else:
                res += "\n"
            res += node.to_string(indent + 1)
        return res
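
# Illustrative only (not part of the original gist): `_TrieNode.to_string`
# renders the trie as an indented tree, marking valid token ends with "*".
# For a trie built from the vocab ["ab", "ac"], `root.to_string()` yields:
#
#   a:
#    b:*
#    c:*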


class Trie:
    def __init__(self, vocab: Iterable[str]):
        self.root = _TrieNode()
        for token in sorted(vocab):
            self.root.add(token)

    def match_once_str(self, string: str) -> List[str]:
        """Match a single token at the start of `string` and return every
        candidate token, i.e. every vocabulary entry that is a prefix of
        `string`, from shortest to longest.

        Args:
            string (str): string to match

        Returns:
            List[str]: tokens matched
        """
        res: List[str] = []
        node = self.root
        for idx, char in enumerate(string):
            if char not in node._branch:
                break
            node = node._branch[char]
            if node.is_valid:
                res.append(string[: idx + 1])
        return res
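
    # Illustrative only (not part of the original gist): `match_once_str`
    # returns every vocabulary token that is a prefix of the input, e.g.
    #
    #   >>> Trie(["a", "ab", "abc"]).match_once_str("abcd")
    #   ['a', 'ab', 'abc']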

    def match_once_end(self, string: str) -> List[int]:
        """Match a single token at the start of `string` and return the end
        position of every candidate token (every vocabulary entry that is a
        prefix of `string`), in increasing order.

        Args:
            string (str): string to match

        Returns:
            List[int]: end positions of the tokens matched
        """
        res: List[int] = []
        node = self.root
        for idx, char in enumerate(string):
            if char not in node._branch:
                break
            node = node._branch[char]
            if node.is_valid:
                res.append(idx + 1)
        return res
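
    # Illustrative only (not part of the original gist): `match_once_end`
    # returns the end offsets of the same matches, e.g.
    #
    #   >>> Trie(["a", "ab", "abc"]).match_once_end("abcd")
    #   [1, 2, 3]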

    def attempt_tokenization_str(self, line: str) -> List[str]:
        """Greedily tokenize a string, preferring the longest match at each step.

        Unknown characters are collected into runs and emitted as single
        out-of-vocabulary tokens. For example, a Trie built from the vocab
        ["aaa"] tokenizing the string "baaabaaa" produces
        ["b", "aaa", "b", "aaa"].

        Args:
            line (str): string to tokenize

        Returns:
            List[str]: tokens
        """
        text = line
        tokens: List[str] = []
        unk = ""
        while len(text) > 0:
            candidates = self.match_once_str(text)
            if len(candidates) > 0:
                # prefer the longest matching token (greedy longest match)
                best = candidates[0]
                for c in candidates[1:]:
                    if len(c) > len(best):
                        best = c
                # flush any pending unknown run before the known token
                if unk != "":
                    tokens.append(unk)
                    unk = ""
                tokens.append(best)
                text = text[len(best) :]
            else:
                # no known token starts here; grow the unknown run by one character
                unk += text[0]
                text = text[1:]
        if unk != "":
            tokens.append(unk)
        return tokens
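
    # Illustrative only (not part of the original gist): the greedy tokenizer
    # always takes the longest match at the current position, e.g.
    #
    #   >>> Trie(["aa", "aaa"]).attempt_tokenization_str("aaaa")
    #   ['aaa', 'a']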

    def attempt_tokenization_bound(self, line: str) -> List[Tuple[int, int]]:
        """Greedily tokenize and return (begin, end) bounds for each **known**
        token detected.

        "Attempt" means the input does not have to consist solely of known
        tokens; unknown spans are simply skipped. For example, a Trie built
        from the vocab ["aaa"] tokenizing the string "baaabaaa" produces
        [(1, 4), (5, 8)].

        Args:
            line (str): string to tokenize

        Returns:
            List[Tuple[int, int]]: bounds of known tokens
        """
        token_boundaries: List[Tuple[int, int]] = []
        begin = 0
        while begin < len(line):
            candidates = self.match_once_end(line[begin:])
            if len(candidates) > 0:
                # take the longest match starting at `begin`
                best = max(candidates)
                token_boundaries.append((begin, begin + best))
                begin += best
            else:
                # no known token starts here; advance by one character
                begin += 1
        return token_boundaries
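
    # Illustrative only (not part of the original gist): bounds cover only the
    # known tokens, so the trailing unknown "a" is skipped here, e.g.
    #
    #   >>> Trie(["aa", "aaa"]).attempt_tokenization_bound("aaaa")
    #   [(0, 3)]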

    def attempt_maximal_tokenization_str(self, line: str) -> List[str]:
        """Tokenize for maximal token coverage; unknown spans are returned as
        single out-of-vocabulary tokens."""
        res: List[str] = []
        last_pos = 0
        for begin, end in self.attempt_maximal_tokenization_bound(line):
            if last_pos != begin:
                # add OOV tokens into the result
                res.append(line[last_pos:begin])
            res.append(line[begin:end])
            last_pos = end
        if last_pos != len(line):
            res.append(line[last_pos:])
        return res
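
    # Illustrative only (not part of the original gist): maximal coverage can
    # differ from greedy tokenization. With the vocab ["ab", "bcd"] on "abcd",
    # greedy takes "ab" and leaves "cd" unknown, while the maximal variant
    # skips "a" to cover "bcd", e.g.
    #
    #   >>> Trie(["ab", "bcd"]).attempt_tokenization_str("abcd")
    #   ['ab', 'cd']
    #   >>> Trie(["ab", "bcd"]).attempt_maximal_tokenization_str("abcd")
    #   ['a', 'bcd']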

    def attempt_maximal_tokenization_bound(self, line: str) -> List[Tuple[int, int]]:
        return [
            (tag.begin, tag.end) for tag in self.attempt_maximal_tokenization_tag(line)
        ]
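
    # Illustrative only (not part of the original gist): the bound variant
    # reports only the known-token spans found by the maximal tokenization, e.g.
    #
    #   >>> Trie(["ab", "bcd"]).attempt_maximal_tokenization_bound("abcd")
    #   [(1, 4)]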

    def attempt_maximal_tokenization_tag(self, line: str) -> List[Tag]:
        """Attempt tokenization for maximal token coverage.

        "Attempt" means the input does not have to consist solely of known
        tokens. NOTE: unknown spans are not tagged.

        Args:
            line (str): line to tokenize

        Returns:
            List[Tag]: tags for the known tokens
        """
        # Dynamic programming:
        # iterate backward from the end of the string to find the path with
        # maximal coverage, where dp_table[pos] represents the maximal coverage
        # achievable from `pos` to the end of the line.
        #
        # Each DP table entry is a tuple of
        #   coverage: maximum coverage achievable from this position
        #   step: step size to the next position
        #     either step == size of the token, if a token starts here
        #     or step == offset to the next position where a token starts
        #   token: the token starting at this position (a `str`),
        #     or `None` if no token starts here
        dp_table: List[Tuple[int, int, Optional[str]]] = [
            (0, 0, None)
            # extend one past the end of the line
            for _ in range(len(line) + 1)
        ]
        next_pos = dp_table[-1]  # index `-1` is the same as `len(line)`
        # loop from len(line) - 1 down to 0
        for pos in range(len(line) - 1, -1, -1):
            # initial best solution: take no token at this position, i.e.
            #   keep the same coverage as the next position and
            #   step directly to the next position that starts a token
            if next_pos[2] is None:
                best_solution = (next_pos[0], next_pos[1] + 1, None)
            else:
                best_solution = (next_pos[0], 1, None)
            # consider every token that starts at this position
            candidates = self.match_once_end(line[pos:])
            for cover in candidates:
                # path to take if this candidate is used
                path = dp_table[pos + cover]
                # coverage achieved if this path is taken
                new_cover = path[0] + cover
                if (
                    # strictly higher coverage
                    best_solution[0] < new_cover
                    or (
                        # same coverage but a longer token at this position,
                        # since we prefer longer leading tokens
                        best_solution[0] == new_cover
                        and best_solution[1] < cover
                    )
                ):
                    best_solution = (new_cover, cover, line[pos : pos + cover])
            next_pos = dp_table[pos] = best_solution
        # forward pass over the DP table to emit the tags
        tags: List[Tag] = []
        pos: int = 0
        coverage, step, token = dp_table[pos]
        while coverage > 0:
            if token is not None:
                tags.append(Tag(pos, pos + step, token))
            pos += step
            coverage, step, token = dp_table[pos]
        return tags
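

# Minimal usage sketch (not part of the original gist); the vocabulary below is
# made up purely for demonstration and exercises both the greedy and the
# maximal-coverage APIs defined above.
if __name__ == "__main__":
    trie = Trie(["ab", "bcd", "aaa"])
    # greedy longest-match tokenization; unknown runs are kept as single tokens
    print(trie.attempt_tokenization_str("abcd"))  # ['ab', 'cd']
    print(trie.attempt_tokenization_bound("baaabaaa"))  # [(1, 4), (5, 8)]
    # maximal-coverage tokenization via the dynamic-programming variant
    print(trie.attempt_maximal_tokenization_str("abcd"))  # ['a', 'bcd']
    print(trie.attempt_maximal_tokenization_tag("abcd"))  # [Tag(begin=1, end=4, token='bcd')]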