@gautierdag
Created May 9, 2024 17:36
Example class for token healing: use get_start_decoding to find the index at which to start constrained decoding over the matched token ids (a short usage sketch follows the listing).

import abc
from typing import Any, List, Tuple

from transformers import PreTrainedTokenizerFast

class BaseTokenizer(abc.ABC):
    def __init__(self) -> None:
        super().__init__()
        # Default special-token ids; concrete subclasses override these.
        self.bos_id = 1
        self.eos_id = 2
        self.pad_id = 3
        self.n_words = 3

    @abc.abstractmethod
    def _encode(self, s: str) -> List[int]:
        pass

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        assert isinstance(s, str)
        t = self._encode(s)
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, tokens: List[int], cut_at_eos: bool = True) -> str:
        if cut_at_eos:
            # Truncate at the first EOS token (inclusive).
            for k, t in enumerate(tokens):
                if t == self.eos_id:
                    tokens = tokens[: k + 1]
                    break
        return self._decode(tokens)

    @abc.abstractmethod
    def _decode(self, tokens: List[int]) -> str:
        pass

    # Lazily built trie mapping decoded token strings to token ids.
    _token_prefix_map: Any = None
    max_token_length: int = 16
    def _build_token_prefix_map(self):
        """
        Build a map from token string to token ids using a trie data structure.
        Taken from Microsoft's guidance library:
        https://github.com/guidance-ai/guidance/blob/23d0ba12720d09bb87b520d6c84462857f5dfcec/guidance/llms/_transformers.py#L74
        """
        import pygtrie

        token_map = pygtrie.CharTrie()
        for i in range(self.n_words):
            try:
                s = self._decode([i])
                self.max_token_length = max(self.max_token_length, len(s))
            except Exception:
                print(f"token id {i} not found in tokenizer")
                continue
            if s in token_map:
                token_map[s].append(i)  # handle duplicate token encodings
            else:
                token_map[s] = [i]
        return token_map
    def prefix_matches(self, prefix: str) -> List[int]:
        """
        Return the list of token ids whose decoded string starts with the given prefix.
        Raises KeyError if no token string starts with the prefix.
        """
        if self._token_prefix_map is None:
            self._token_prefix_map = self._build_token_prefix_map()
        return [v for arr in self._token_prefix_map.values(prefix=prefix) for v in arr]
    def get_start_decoding(self, prompt_tokens: List[int]) -> Tuple[int, List[int]]:
        """
        Given encoded prompt tokens, return the index at which token healing
        should start and the list of token ids that match the possible healing tokens.

        The possible healing tokens are found by decoding progressively longer
        suffixes of the prompt, growing backwards from its end, and keeping the
        longest suffix string that still has prefix matches, up to the maximum
        token length.
        """
        matches, subseq = [], ""
        i, out_index = len(prompt_tokens) - 1, len(prompt_tokens) - 1
        while len(subseq) < self.max_token_length and i >= 0:
            subseq = self.decode(prompt_tokens[i:])
            try:
                matches = self.prefix_matches(prefix=subseq)
                out_index = i
            except KeyError:
                pass
            i -= 1
        return out_index, matches
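
# Illustration of the prefix matching above (hypothetical vocabulary, purely for intuition):
# if the trie contains the token strings ":", "://", and ":=", then prefix_matches(":")
# returns the ids of all three, prefix_matches(":/") returns only the id of "://", and
# prefix_matches("??") raises KeyError because no token string starts with that prefix.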

class HuggingFaceTokenizer(BaseTokenizer):
    def __init__(self, model_path: str):
        self.tokenizer_model = PreTrainedTokenizerFast(
            tokenizer_file=model_path,
            clean_up_tokenization_spaces=False,
        )
        self.n_words = len(self.tokenizer_model)
        self.bos_id = self.tokenizer_model.bos_token_id
        self.eos_id = self.tokenizer_model.eos_token_id
        self.pad_id: int = -1

    def _encode(self, s: str) -> List[int]:
        return self.tokenizer_model.encode(s)

    def _decode(self, tokens: List[int]) -> str:
        return self.tokenizer_model.decode(tokens, skip_special_tokens=False)
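
# Usage sketch (not part of the original gist): the tokenizer.json path, the prompt, and the
# constrained-decoding step below are illustrative assumptions, not a definitive recipe.
if __name__ == "__main__":
    tokenizer = HuggingFaceTokenizer("tokenizer.json")  # path to a HF fast-tokenizer file

    prompt = "The link is https:/"
    prompt_tokens = tokenizer.encode(prompt, bos=True, eos=False)

    # Find where token healing should start and which token ids are allowed at that position.
    start_idx, allowed_ids = tokenizer.get_start_decoding(prompt_tokens)

    # Feed the model the healed prefix and, at the first generation step, mask the logits so
    # that only `allowed_ids` can be sampled; afterwards decode normally. The model call and
    # generation loop are out of scope for this gist.
    healed_prefix = prompt_tokens[:start_idx]
    print(start_idx, len(allowed_ids), healed_prefix[-3:])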