KLabs Trie tokenizer
# Copyright 2023 KLabs Co., Ltd.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
# Filename: klabs_nlp/trie.py
from __future__ import annotations

from typing import Dict, Iterable, List, NamedTuple, Optional, Tuple


class Tag(NamedTuple):
    # beginning position of the tag
    begin: int
    # ending position of the tag
    end: int
    # token
    token: str

class _TrieNode:
    def __init__(
        self, is_valid: bool = False, branch: Optional[Dict[str, _TrieNode]] = None
    ):
        self.is_valid = is_valid
        if branch is None:
            branch = {}
        self._branch: Dict[str, _TrieNode] = branch

    def get_branch(self, char: str) -> _TrieNode:
        if char not in self._branch:
            new_node = _TrieNode()
            self._branch[char] = new_node
            return new_node
        return self._branch[char]

    def add(self, chars: str) -> None:
        if chars == "":
            self.is_valid = True
        else:
            self.get_branch(chars[0]).add(chars[1:])

    def to_string(self, indent: int = 0) -> str:
        res = ""
        ind = " " * indent
        for key, node in self._branch.items():
            res += f"{ind}{key}:"
            if node.is_valid:
                res += "*\n"
            else:
                res += "\n"
            res += node.to_string(indent + 1)
        return res
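
# Illustrative note (added): for a vocab of ["ab", "ac"], repeated _TrieNode.add
# calls build a single root branch "a" with two valid leaves, and
# root.to_string() renders it as
#   a:
#    b:*
#    c:*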

class Trie:
    def __init__(self, vocab: Iterable[str]):
        self.root = _TrieNode()
        for token in sorted(vocab):
            self.root.add(token)

    def match_once_str(self, string: str) -> List[str]:
        """match once (one token) and return tokens matched

        Args:
            string (str): string to match

        Returns:
            List[str]: tokens matched
        """
        res = []
        node = self.root
        for idx, char in enumerate(string):
            if char not in node._branch:
                break
            node = node._branch[char]
            if node.is_valid:
                res.append(string[: idx + 1])
        return res

    def match_once_end(self, string: str) -> List[int]:
        """match once (one token) and return end-points of tokens matched

        Args:
            string (str): string to match

        Returns:
            List[int]: end points of tokens matched
        """
        res: List[int] = []
        node = self.root
        for idx, char in enumerate(string):
            if char not in node._branch:
                break
            node = node._branch[char]
            if node.is_valid:
                res.append(idx + 1)
        return res
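
    # Illustrative note (added): with a vocab of ["aa", "aab"], match_once_str("aaba")
    # returns ["aa", "aab"] (every prefix of the input that is a known token), and
    # match_once_end("aaba") returns the corresponding end offsets [2, 3].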

    def attempt_tokenization_str(self, line: str) -> List[str]:
        """greedy tokenize string

        For example, a Trie with a dict of ["aaa"] tokenizing the string
        "baaabaaa" will produce the output ["b", "aaa", "b", "aaa"].

        Args:
            line (str): string to tokenize

        Returns:
            List[str]: tokens
        """
        text = line
        tokens: List[str] = []
        unk = ""
        while len(text) > 0:
            candidates = self.match_once_str(text)
            if len(candidates) > 0:
                best = candidates[0]
                for c in candidates[1:]:
                    if len(c) > len(best):
                        best = c
                if unk != "":
                    tokens.append(unk)
                    unk = ""
                tokens.append(best)
                text = text[len(best) :]
            else:
                unk += text[0]
                text = text[1:]
        if unk != "":
            tokens.append(unk)
        return tokens

    def attempt_tokenization_bound(self, line: str) -> List[Tuple[int, int]]:
        """attempt greedy tokenization and produce bounds for each **known** token detected

        "attempt" means that the input does not have to consist of only known tokens.

        For example, a Trie with a dict of ["aaa"] tokenizing the string
        "baaabaaa" will produce the output [(1, 4), (5, 8)].

        Args:
            line (str): string to tokenize

        Returns:
            List[Tuple[int, int]]: bounds
        """
        token_boundaries: List[Tuple[int, int]] = []
        begin = 0
        while begin < len(line):
            candidates = self.match_once_end(line[begin:])
            if len(candidates) > 0:
                best = max(candidates)
                token_boundaries.append((begin, begin + best))
                begin += best
            else:
                begin += 1
        return token_boundaries

    def attempt_maximal_tokenization_str(self, line: str) -> List[str]:
        res: List[str] = []
        last_pos = 0
        for begin, end in self.attempt_maximal_tokenization_bound(line):
            if last_pos != begin:
                # add OOV tokens into result
                res.append(line[last_pos:begin])
            res.append(line[begin:end])
            last_pos = end
        if last_pos != len(line):
            res.append(line[last_pos:])
        return res

    def attempt_maximal_tokenization_bound(self, line: str) -> List[Tuple[int, int]]:
        return [
            (tag.begin, tag.end) for tag in self.attempt_maximal_tokenization_tag(line)
        ]

    def attempt_maximal_tokenization_tag(self, line: str) -> List[Tag]:
        """attempt tokenization for maximal token coverage

        "attempt" means that the input does not have to consist of only known tokens.
        NOTE: Unknown tokens are not tagged

        Args:
            line (str): line to tokenize

        Returns:
            List[Tag]: tokenization of known tokens
        """
        # Dynamic programming:
        # iterate backward from the end of the string to find the optimal path
        # with maximal coverage. Each position in the DP table is optimized for
        # maximal coverage, where dp[pos] represents the maximal coverage that
        # can be achieved starting at `pos`.
        # The DP table entry for each position in the text is a tuple of
        #   coverage: maximum coverage of the current path
        #   step: step size to the next position,
        #       either step == size of the token, if at the start position of a token,
        #       or step == offset to the next start position of a token
        #   token: the token this position is the start of,
        #       either token is a `str`, if at the start position of a token,
        #       or token is `None`
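        # Illustrative example (added note, not from the original comments): with a
        # vocab of {"aa", "aab", "ba"} and line "aaba", greedy matching would take
        # "aab" and leave the final "a" uncovered (coverage 3), while this DP picks
        # "aa" + "ba" for full coverage (coverage 4).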
        dp_table: List[Tuple[int, int, Optional[str]]] = [
            (0, 0, None)
            # Extend past the end of the line by 1
            for _ in range(len(line) + 1)
        ]
        next_pos = dp_table[-1]  # index `-1` is the same as `len(line)`
        # loop over pos from len(line) - 1 down to 0
        for pos in range(len(line) - 1, -1, -1):
            # mark the initial best solution as "no token", that is:
            #   same coverage as the next position,
            #   jump to the next position that is not None (has a valid token),
            #   mark as no token
            if next_pos[2] is None:
                best_solution = (next_pos[0], next_pos[1] + 1, None)
            else:
                best_solution = (next_pos[0], 1, None)
            # find possible paths
            candidates = self.match_once_end(line[pos:])
            for cover in candidates:
                # path to take if this candidate is used
                path = dp_table[pos + cover]
                # new coverage if this path is taken
                new_cover = path[0] + cover
                if (
                    # higher coverage
                    best_solution[0] < new_cover
                    or (
                        # same coverage but a longer token at this position,
                        # since we prefer longer leading tokens
                        best_solution[0] == new_cover
                        and best_solution[1] < cover
                    )
                ):
                    best_solution = (new_cover, cover, line[pos : pos + cover])
            next_pos = dp_table[pos] = best_solution
        # forward pass on the DP table to collect tags
        tags: List[Tag] = []
        pos: int = 0
        coverage, step, token = dp_table[pos]
        while coverage > 0:
            if token is not None:
                tags.append(Tag(pos, pos + step, token))
            pos += step
            coverage, step, token = dp_table[pos]
        return tags
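
# --- Illustrative usage sketch (added; not part of the original gist) ---
# The vocab and strings below are hypothetical examples chosen to show how the
# greedy and DP-based tokenizers can differ on the same input.
if __name__ == "__main__":
    trie = Trie(["aa", "aab", "ba"])

    # Greedy tokenization takes the longest match first ("aab"), leaving the
    # trailing "a" as an out-of-vocabulary token.
    print(trie.attempt_tokenization_str("aaba"))            # ['aab', 'a']
    print(trie.attempt_tokenization_bound("aaba"))          # [(0, 3)]

    # Maximal tokenization covers the whole string with "aa" + "ba".
    print(trie.attempt_maximal_tokenization_str("aaba"))    # ['aa', 'ba']
    print(trie.attempt_maximal_tokenization_bound("aaba"))  # [(0, 2), (2, 4)]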