#
# correct vocab of matsuo-lab/weblab-10b
#
vocab = {}
# Observed id -> raw byte layout in the tokenizer ("(none)" = no byte token):
# 95  -> (none)
# 96  -> \xa1
# ...
# 107 -> \xac
# 244 -> \xad
# 108 -> \xae
# 109 -> \xaf
# 110 -> \xb0
# ...
# 125 -> \xbf
# 126 -> (none)
# 158 -> (none)
# 159 -> \xe3
# ...
# 165 -> \xe9
# 166 -> (none)
from sys import byteorder
# ids 96..107 hold bytes \xa1..\xac
off = 96 - 0xa1
for c in range(0xa1, 0xac + 1):
    id = c + off
    vocab[id] = c.to_bytes(1, byteorder)
# ids 108..125 hold bytes \xae..\xbf (\xad sits at id 244 per the table above)
off = 108 - 0xae
for c in range(0xae, 0xbf + 1):
    id = c + off
    vocab[id] = c.to_bytes(1, byteorder)
# ids 159..165 hold bytes \xe3..\xe9
off = 159 - 0xe3
for c in range(0xe3, 0xe9 + 1):
    id = c + off
    vocab[id] = c.to_bytes(1, byteorder)
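# Sanity check (illustrative addition, not in the original gist): the offset
# arithmetic above should give exactly these endpoints.
assert vocab[96] == b'\xa1' and vocab[107] == b'\xac'
assert vocab[108] == b'\xae' and vocab[125] == b'\xbf'
assert vocab[159] == b'\xe3' and vocab[165] == b'\xe9'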
#
# long-vowel mark "ー" + leading two bytes of a katakana character
vocab[12448] = b'\xe3\x83\xbc\xe3\x83'
# "ン" + leading two bytes of a katakana character
vocab[25558] = b'\xe3\x83\xb3\xe3\x83'
# "シ" + leading two bytes of a katakana character
vocab[35259] = b'\xe3\x82\xb7\xe3\x83'
# "フ" + leading two bytes of a small (contracted-sound) kana
vocab[38892] = b'\xe3\x83\x95\xe3\x82'
#
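# Illustrative check (my addition, not in the original gist): a fragment entry
# plus the trailing byte of the next token reassembles valid UTF-8 --
# b'\xab' completes "ル" (U+30EB) after the dangling b'\xe3\x83'.
assert (vocab[12448] + b'\xab').decode('utf-8') == 'ール'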
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("matsuo-lab/weblab-10b")
# Unicode blocks to scan against the tokenizer
plane_map = {
    "cjk_symbols_and_punctuation": range(0x3000, 0x303F + 1),
    "hiragana": range(0x3040, 0x309F + 1),
    "katakana": range(0x30A0, 0x30FF + 1),
    "cjk_unified_ideographs": range(0x4E00, 0x9FFF + 1),
    # "cjk_unified_ideographs_extension_a": range(0x3400, 0x4DBF + 1),
    "halfwidth_and_fullwidth_forms": range(0xFF00, 0xFFEF + 1),
}
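# Illustrative note (my addition): the scan below round-trips every codepoint
# in these blocks; each one is a 3-byte UTF-8 sequence, e.g. "ア" (U+30A2):
assert 'ア'.encode('utf-8') == b'\xe3\x82\xa2'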
# o = open("vocab.txt", "w", encoding="utf-8")
# o.write("cp decoded [len] u8 [len] id, ...\n")
for name in plane_map:
    cp_range = plane_map[name]
    for cp in cp_range:
        ids = tokenizer.encode(chr(cp), add_special_tokens=False)
        decoded = tokenizer.decode(ids)
        u8 = decoded.encode('utf-8')
        # o.write("U+{:x} {} [{:d}] {} [{:d}] {}\n".format(cp, decoded, len(u8), u8, len(ids), ', '.join(map(str, ids))))
        if len(ids) == 1:
            # one id carries the whole character
            id_0, = ids
            if id_0 in vocab:
                mb_0 = vocab[id_0]
                if mb_0 != u8:
                    print("error: vocab mismatch. {:d} [{:d}] {} != {}".format(id_0, len(mb_0), mb_0, u8))
            else:
                # correct
                vocab[id_0] = u8
        elif len(ids) == 2 and len(u8) == 3:
            # a 3-byte character split across two ids: use whichever id is
            # already known to infer the byte slice owned by the other
            id_0, id_1 = ids
            if id_0 in vocab:
                mb_0 = vocab[id_0]
                if mb_0 != u8[0:len(mb_0)]:
                    print("error: vocab mismatch. {:d} [{:d}] {} != {}".format(id_0, len(mb_0), mb_0, u8[0:len(mb_0)]))
                mb_1 = vocab[id_1] = u8[len(mb_0):3]
            elif id_1 in vocab:
                mb_1 = vocab[id_1]
                if mb_1 != u8[3-len(mb_1):3]:
                    print("error: vocab mismatch. {:d} [{:d}] {} != {}".format(id_1, len(mb_1), mb_1, u8[3-len(mb_1):3]))
                mb_0 = vocab[id_0] = u8[0:3-len(mb_1)]
            else:
                # correct ?
                vocab[id_0] = u8[0:2]
                vocab[id_1] = u8[2:3]
        elif len(ids) == 3 and len(u8) == 3:
            # one id per byte: only verify against the known single-byte entries
            id_0, id_1, id_2 = ids
            if id_0 in vocab:
                mb_0 = vocab[id_0]
                if mb_0 != u8[0:1]:
                    print("error: vocab mismatch. {:d} [{:d}] {} != {}".format(id_0, len(mb_0), mb_0, u8[0:1]))
            if id_1 in vocab:
                mb_1 = vocab[id_1]
                if mb_1 != u8[1:2]:
                    print("error: vocab mismatch. {:d} [{:d}] {} != {}".format(id_1, len(mb_1), mb_1, u8[1:2]))
            if id_2 in vocab:
                mb_2 = vocab[id_2]
                if mb_2 != u8[2:3]:
                    print("error: vocab mismatch. {:d} [{:d}] {} != {}".format(id_2, len(mb_2), mb_2, u8[2:3]))
            # vocab[id_0] = u8[0:1]
            # vocab[id_1] = u8[1:2]
            # vocab[id_2] = u8[2:3]
        else:
            print("error: U+{:x} len(u8) {:d} num-ids {:d}".format(cp, len(u8), len(ids)))
# o.close()
if __name__ == "__main__":
    print(vocab)
opparco commented Aug 19, 2023

usage

from weblab_10b import vocab

if i in vocab:
    text = vocab[i]
else:
    text = tokenizer.decode([i]).encode('utf-8')
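
A fuller sketch of the same pattern (my addition, not from the gist; decode_ids is a hypothetical helper): bytes are accumulated across ids so the fragment entries recombine before UTF-8 decoding.

from weblab_10b import vocab
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("matsuo-lab/weblab-10b")

def decode_ids(ids):
    # accumulate raw bytes so fragment tokens recombine across id boundaries
    buf = b''
    for i in ids:
        if i in vocab:
            buf += vocab[i]
        else:
            buf += tokenizer.decode([i]).encode('utf-8')
    return buf.decode('utf-8', errors='replace')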
