@saharNooby
Last active May 5, 2023 18:57
Probably the dumbest pure Python implementation of 20B_tokenizer.json (a BPE tokenizer for the GPT-NeoX model); the only dependency is the regex package.
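A minimal usage sketch, assuming a local copy of 20B_tokenizer.json in the working directory (load_tokenizer, encode and decode are defined below):

encode, decode = load_tokenizer('./20B_tokenizer.json')
tokens = encode('Hello world')          # list[int] of token ids
assert decode(tokens) == 'Hello world'  # decoding round-trips the (NFC-normalized) input
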
import regex
import json
import unicodedata
from typing import Tuple, Callable, Union
# Parses the tokenizer config and returns encode and decode functions.
def load_tokenizer(config_path: str) -> Tuple[Callable[[str], list[int]], Callable[[list[int]], str]]:
    # Maps any byte 0..255 to a printable Unicode character
    # (the same byte-to-Unicode table used by GPT-2 style byte-level BPE; e.g. the space byte 32 maps to "Ġ").
    byte_to_unicode: dict[int, str] = {
        33: "!",
        34: "\"",
        35: "#",
        36: "$",
        37: "%",
        38: "&",
        39: "\'",
        40: "(",
        41: ")",
        42: "*",
        43: "+",
        44: ",",
        45: "-",
        46: ".",
        47: "/",
        48: "0",
        49: "1",
        50: "2",
        51: "3",
        52: "4",
        53: "5",
        54: "6",
        55: "7",
        56: "8",
        57: "9",
        58: ":",
        59: ";",
        60: "<",
        61: "=",
        62: ">",
        63: "?",
        64: "@",
        65: "A",
        66: "B",
        67: "C",
        68: "D",
        69: "E",
        70: "F",
        71: "G",
        72: "H",
        73: "I",
        74: "J",
        75: "K",
        76: "L",
        77: "M",
        78: "N",
        79: "O",
        80: "P",
        81: "Q",
        82: "R",
        83: "S",
        84: "T",
        85: "U",
        86: "V",
        87: "W",
        88: "X",
        89: "Y",
        90: "Z",
        91: "[",
        92: "\\",
        93: "]",
        94: "^",
        95: "_",
        96: "`",
        97: "a",
        98: "b",
        99: "c",
        100: "d",
        101: "e",
        102: "f",
        103: "g",
        104: "h",
        105: "i",
        106: "j",
        107: "k",
        108: "l",
        109: "m",
        110: "n",
        111: "o",
        112: "p",
        113: "q",
        114: "r",
        115: "s",
        116: "t",
        117: "u",
        118: "v",
        119: "w",
        120: "x",
        121: "y",
        122: "z",
        123: "{",
        124: "|",
        125: "}",
        126: "~",
        161: "¡",
        162: "¢",
        163: "£",
        164: "¤",
        165: "¥",
        166: "¦",
        167: "§",
        168: "¨",
        169: "©",
        170: "ª",
        171: "«",
        172: "¬",
        174: "®",
        175: "¯",
        176: "°",
        177: "±",
        178: "²",
        179: "³",
        180: "´",
        181: "µ",
        182: "¶",
        183: "·",
        184: "¸",
        185: "¹",
        186: "º",
        187: "»",
        188: "¼",
        189: "½",
        190: "¾",
        191: "¿",
        192: "À",
        193: "Á",
        194: "Â",
        195: "Ã",
        196: "Ä",
        197: "Å",
        198: "Æ",
        199: "Ç",
        200: "È",
        201: "É",
        202: "Ê",
        203: "Ë",
        204: "Ì",
        205: "Í",
        206: "Î",
        207: "Ï",
        208: "Ð",
        209: "Ñ",
        210: "Ò",
        211: "Ó",
        212: "Ô",
        213: "Õ",
        214: "Ö",
        215: "×",
        216: "Ø",
        217: "Ù",
        218: "Ú",
        219: "Û",
        220: "Ü",
        221: "Ý",
        222: "Þ",
        223: "ß",
        224: "à",
        225: "á",
        226: "â",
        227: "ã",
        228: "ä",
        229: "å",
        230: "æ",
        231: "ç",
        232: "è",
        233: "é",
        234: "ê",
        235: "ë",
        236: "ì",
        237: "í",
        238: "î",
        239: "ï",
        240: "ð",
        241: "ñ",
        242: "ò",
        243: "ó",
        244: "ô",
        245: "õ",
        246: "ö",
        247: "÷",
        248: "ø",
        249: "ù",
        250: "ú",
        251: "û",
        252: "ü",
        253: "ý",
        254: "þ",
        255: "ÿ",
        0: "Ā",
        1: "ā",
        2: "Ă",
        3: "ă",
        4: "Ą",
        5: "ą",
        6: "Ć",
        7: "ć",
        8: "Ĉ",
        9: "ĉ",
        10: "Ċ",
        11: "ċ",
        12: "Č",
        13: "č",
        14: "Ď",
        15: "ď",
        16: "Đ",
        17: "đ",
        18: "Ē",
        19: "ē",
        20: "Ĕ",
        21: "ĕ",
        22: "Ė",
        23: "ė",
        24: "Ę",
        25: "ę",
        26: "Ě",
        27: "ě",
        28: "Ĝ",
        29: "ĝ",
        30: "Ğ",
        31: "ğ",
        32: "Ġ",
        127: "ġ",
        128: "Ģ",
        129: "ģ",
        130: "Ĥ",
        131: "ĥ",
        132: "Ħ",
        133: "ħ",
        134: "Ĩ",
        135: "ĩ",
        136: "Ī",
        137: "ī",
        138: "Ĭ",
        139: "ĭ",
        140: "Į",
        141: "į",
        142: "İ",
        143: "ı",
        144: "IJ",
        145: "ij",
        146: "Ĵ",
        147: "ĵ",
        148: "Ķ",
        149: "ķ",
        150: "ĸ",
        151: "Ĺ",
        152: "ĺ",
        153: "Ļ",
        154: "ļ",
        155: "Ľ",
        156: "ľ",
        157: "Ŀ",
        158: "ŀ",
        159: "Ł",
        160: "ł",
        173: "Ń"
    }

    # Reverse of byte_to_unicode.
    unicode_to_bytes: dict[str, int] = {byte_to_unicode[b]: b for b in byte_to_unicode.keys()}

    with open(config_path, 'r') as f:
        config = json.loads(f.read())

    vocab: dict[str, int] = config['model']['vocab']
    merges: list[str] = config['model']['merges']
    added_tokens: dict[int, str] = {t['id']: t['content'] for t in config['added_tokens']}

    # Register added (special) tokens in the vocabulary under their byte-encoded form,
    # so that encode and decode can treat them like regular vocabulary entries.
    for added_token_id in added_tokens:
        encoded_added_token = ''.join([byte_to_unicode[b] for b in added_tokens[added_token_id].encode('utf-8')])
        vocab[encoded_added_token] = added_token_id

    vocab_reversed: dict[int, str] = {vocab[t]: t for t in vocab.keys()}

    # Replaces every occurrence of subsequence a in lst with b, in place.
    def replace_subsequence(lst: list, a: list, b: list) -> None:
        for i in range(len(lst)):
            if lst[i:i + len(a)] == a:
                lst[i:i + len(a)] = b

    # Splits a string into word-like chunks using the GPT-2 style pre-tokenization regex.
    def split_words(s: str) -> list[str]:
        result: list[str] = []

        pattern: regex.Pattern = regex.compile(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+")

        for m in regex.finditer(pattern, s):
            result.append(m.group())

        return result

    # Splits a string into plain-text chunks (str) and added-token ids (int),
    # always cutting at the earliest occurrence of any added token.
    def encode_added_tokens(s: str) -> list[Union[int, str]]:
        result: list[Union[int, str]] = []

        remainder: str = s

        while len(remainder) > 0:
            # Find the added token that occurs earliest in the remaining text.
            nearest_pos = len(remainder)
            nearest_token = -1

            for added_token_id in added_tokens:
                pos = remainder.find(added_tokens[added_token_id])

                if pos != -1 and pos < nearest_pos:
                    nearest_pos = pos
                    nearest_token = added_token_id

            if nearest_pos == len(remainder):
                # No added tokens left; the rest is plain text.
                result.append(remainder)
                break

            if nearest_pos != 0:
                result.append(remainder[:nearest_pos])

            result.append(nearest_token)

            remainder = remainder[nearest_pos + len(added_tokens[nearest_token]):]

        return result

    # Converts a string to a list of tokens.
    def encode(s: str) -> list[int]:
        s = unicodedata.normalize('NFC', s)

        result: list[int] = []

        for part in encode_added_tokens(s):
            if type(part) == int:
                # Already an added token id.
                result.append(part)
                continue

            for word in split_words(part):
                # Start with one token per UTF-8 byte.
                tokens = [vocab[byte_to_unicode[b]] for b in word.encode('utf-8')]

                for added_token_id in added_tokens.keys():
                    added_token_tokens = [vocab[byte_to_unicode[b]] for b in added_tokens[added_token_id].encode('utf-8')]

                    replace_subsequence(tokens, added_token_tokens, [added_token_id])

                # Apply BPE merges in the order they are listed in the tokenizer config.
                for merge in merges:
                    space = merge.find(' ')
                    assert space != -1

                    token_a = vocab[merge[0:space]]
                    token_b = vocab[merge[space + 1:]]
                    token_merged = vocab[merge[0:space] + merge[space + 1:]]

                    for i in range(len(tokens) - 1):
                        if i + 1 < len(tokens) and tokens[i] == token_a and tokens[i + 1] == token_b:
                            # Replace the pair with the merged token and shift the tail left.
                            tokens[i] = token_merged
                            tokens[i + 1:] = tokens[i + 2:]

                result += tokens

        return result

    # Converts list of tokens to a string.
    def decode(tokens: list[int]) -> str:
        result = bytes()

        for token in tokens:
            result += bytes([unicode_to_bytes[c] for c in vocab_reversed[token]])

        return result.decode('utf-8')

    return encode, decode


# The code below tests the correctness of the tokenizer.
# It may safely be removed.
def test() -> None:
    config_path = r"./20B_tokenizer.json"

    encode, decode = load_tokenizer(config_path)

    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_file(config_path)

    # ---

    test_strings = [
        '\n a',
        # An ambiguous edge case: it should tokenize into ["\n", " ~"], not ["\n ", "~"].
        # This test will fail unless the tokenizer splits words by the regex above.
        '\n ~',
        '\n \u597d',
        # Special tokens
        '-> <|endoftext|><|padding|> int',
        # Just some Unicode stuff
        'I\'ll \'d test блабла 以下は、]) -> <|endoftext|><|padding|> int',
        # This test will fail unless the tokenizer normalizes its input to NFC.
        "κόσμε"
    ]

    import random

    r = random.Random(42)

    for i in range(256):
        test_strings += [' ' * i]

    for i in range(256):
        x = chr(r.randrange(0, 256))
        x = x * r.randrange(1, 32)

        try:
            x.encode('utf-8')
            test_strings += [x]
        except UnicodeEncodeError:
            pass

    for i in range(256):
        x = chr(r.randrange(0, 1114112))
        x = x * r.randrange(1, 4)

        try:
            x.encode('utf-8')
            test_strings += [x]
        except UnicodeEncodeError:
            # Skip code points that cannot be encoded as UTF-8 (e.g. lone surrogates).
            pass

    for test_string in test_strings:
        print()
        print(json.dumps(test_string))

        encoded_expected = tokenizer.encode(test_string).ids
        print('expect', encoded_expected)

        encoded_actual = encode(test_string)
        print('actual', encoded_actual)

        assert str(encoded_expected) == str(encoded_actual)

        decoded_actual = decode(encoded_actual)
        print(json.dumps(decoded_actual))

        assert unicodedata.normalize('NFC', test_string) == decoded_actual


if __name__ == '__main__':
    test()
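
# Note: the self-test compares this implementation against the reference HuggingFace Tokenizer,
# so running it needs the third-party regex and tokenizers packages (for example: pip install regex tokenizers)
# and a local copy of 20B_tokenizer.json.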