Train a BPE encoder by yourself
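The first file below adapts tiktoken's educational BPE implementation (`tiktoken._educational`) so you can train your own byte-pair encoding on a text corpus and splice the learned merges into the metavoice-1b tokenizer metadata. As a rough sketch of the workflow (the `corpus` string here is a placeholder; the real script assembles it from a CSV of caption files):

enc = SimpleBytePairEncoding.train(corpus, vocab_size=512, pat_str=gpt2_pattern)
tokens = enc.encode("Привет!")
assert enc.decode(tokens) == "Привет!"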
import tiktoken
from tiktoken._educational import *

"""This is an educational implementation of the byte pair encoding algorithm."""
import collections
from typing import Optional

import regex

import tiktoken


class SimpleBytePairEncoding:
    def __init__(self, *, pat_str: str, mergeable_ranks: dict[bytes, int]) -> None:
        """Creates an Encoding object."""
        # A regex pattern string that is used to split the input text
        self.pat_str = pat_str
        # A dictionary mapping token bytes to their ranks. The ranks correspond to merge priority
        self.mergeable_ranks = mergeable_ranks

        self._decoder = {token: token_bytes for token_bytes, token in mergeable_ranks.items()}
        self._pat = regex.compile(pat_str)

    def encode(self, text: str, visualise: Optional[str] = "colour") -> list[int]:
        """Encodes a string into tokens.

        >>> enc.encode("hello world")
        [388, 372]
        """
        # Use the regex to split the text into (approximately) words
        words = self._pat.findall(text)
        tokens = []
        for word in words:
            # Turn each word into tokens, using the byte pair encoding algorithm
            word_bytes = word.encode("utf-8")
            word_tokens = bpe_encode(self.mergeable_ranks, word_bytes, visualise=visualise)
            tokens.extend(word_tokens)
        return tokens

    def decode_bytes(self, tokens: list[int]) -> bytes:
        """Decodes a list of tokens into bytes.

        >>> enc.decode_bytes([388, 372])
        b'hello world'
        """
        return b"".join(self._decoder[token] for token in tokens)

    def decode(self, tokens: list[int]) -> str:
        """Decodes a list of tokens into a string.

        Decoded bytes are not guaranteed to be valid UTF-8. In that case, we replace
        the invalid bytes with the replacement character "�".

        >>> enc.decode([388, 372])
        'hello world'
        """
        return self.decode_bytes(tokens).decode("utf-8", errors="replace")

    def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]:
        """Decodes a list of tokens into a list of bytes.

        Useful for visualising how a string is tokenised.

        >>> enc.decode_tokens_bytes([388, 372])
        [b'hello', b' world']
        """
        return [self._decoder[token] for token in tokens]

    @staticmethod
    def train(training_data: str, vocab_size: int, pat_str: str):
        """Train a BPE tokeniser on some data!"""
        mergeable_ranks = bpe_train(data=training_data, vocab_size=vocab_size, pat_str=pat_str, visualise=False)
        return SimpleBytePairEncoding(pat_str=pat_str, mergeable_ranks=mergeable_ranks)

    @staticmethod
    def from_tiktoken(encoding):
        if isinstance(encoding, str):
            encoding = tiktoken.get_encoding(encoding)
        return SimpleBytePairEncoding(
            pat_str=encoding._pat_str, mergeable_ranks=encoding._mergeable_ranks
        )
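# Illustrative usage (a sketch, not executed as part of this file): the same class can wrap an
# existing tiktoken encoding, which is handy for sanity-checking encode/decode before training
# your own ranks. "gpt2" here is just an example encoding name.
#
#   enc = SimpleBytePairEncoding.from_tiktoken("gpt2")
#   tokens = enc.encode("hello world")
#   assert enc.decode(tokens) == "hello world"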
def bpe_encode(
    mergeable_ranks: dict[bytes, int], input: bytes, visualise: Optional[str] = "colour"
) -> list[int]:
    parts = [bytes([b]) for b in input]
    while True:
        # See the intermediate merges play out!
        if visualise:
            if visualise in ["colour", "color"]:
                visualise_tokens(parts)
            elif visualise == "simple":
                print(parts)

        # Iterate over all pairs and find the pair we want to merge the most
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank

        # If there were no pairs we could merge, we're done!
        if min_rank is None:
            break
        assert min_idx is not None

        # Otherwise, merge that pair and leave the rest unchanged. Then repeat.
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :]

    if visualise:
        print()

    tokens = [mergeable_ranks[part] for part in parts]
    return tokens
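# Worked example (illustrative): with trained ranks, encoding the bytes of "hello" starts from
# single-byte parts and repeatedly merges the adjacent pair with the lowest (best) rank:
#
#   [b'h', b'e', b'l', b'l', b'o']
#   -> e.g. [b'h', b'e', b'll', b'o']   # if b'l' + b'l' is the best-ranked mergeable pair
#   -> ...                              # until no adjacent pair is in mergeable_ranks
#
# The exact merge order depends entirely on the trained mergeable_ranks.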
def bpe_train(
    data: str, vocab_size: int, pat_str: str, visualise: Optional[str] = "colour"
) -> dict[bytes, int]:
    # First, add tokens for each individual byte value
    if vocab_size < 2**8:
        raise ValueError("vocab_size must be at least 256, so we can encode all bytes")
    ranks = {}
    for i in range(2**8):
        ranks[bytes([i])] = i

    # Splinter up our data into lists of bytes
    # data = "Hello world"
    # words = [
    #     [b'H', b'e', b'l', b'l', b'o'],
    #     [b' ', b'w', b'o', b'r', b'l', b'd']
    # ]
    words: list[list[bytes]] = [
        [bytes([b]) for b in word.encode("utf-8")] for word in regex.findall(pat_str, data)
    ]

    # Now, use our data to figure out which merges we should make
    while len(ranks) < vocab_size:
        # Find the most common pair. This will become our next token
        stats = collections.Counter()
        for piece in words:
            for pair in zip(piece[:-1], piece[1:]):
                stats[pair] += 1

        most_common_pair = max(stats, key=lambda x: stats[x])
        token_bytes = most_common_pair[0] + most_common_pair[1]
        token = len(ranks)
        # Add the new token!
        ranks[token_bytes] = token

        # Now merge that most common pair in all the words. That is, update our training data
        # to reflect our decision to make that pair into a new token.
        new_words = []
        for word in words:
            new_word = []
            i = 0
            while i < len(word) - 1:
                if (word[i], word[i + 1]) == most_common_pair:
                    # We found our pair! Merge it
                    new_word.append(token_bytes)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            if i == len(word) - 1:
                new_word.append(word[i])
            new_words.append(new_word)
        words = new_words

        # See the intermediate merges play out!
        if visualise:
            print(f"The current most common pair is {most_common_pair[0]} + {most_common_pair[1]}")
            print(f"So we made {token_bytes} our {len(ranks)}th token")
            if visualise in ["colour", "color"]:
                print("Now the first fifty words in our training data look like:")
                visualise_tokens([token for word in words[:50] for token in word])
            elif visualise == "simple":
                print("Now the first twenty words in our training data look like:")
                for word in words[:20]:
                    print(word)
            print("\n")

    return ranks
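# Note on vocab_size (illustrative numbers): ranks always starts with the 256 single-byte tokens,
# so training with vocab_size=512 learns 512 - 256 = 256 merge rules, and vocab_size=260 would
# learn only 4. Each learned token's rank equals len(ranks) at the time it was added, so ranks
# also records the order in which merges were learned.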
def visualise_tokens(token_values: list[bytes]) -> None:
    background = [f"\u001b[48;5;{i}m" for i in [167, 179, 185, 77, 80, 68, 134]]
    # If token boundaries do not occur at unicode character boundaries, it's unclear how best to
    # visualise the token. Here, we'll just use the unicode replacement character to represent some
    # fraction of a character.
    unicode_token_values = [x.decode("utf-8", errors="replace") for x in token_values]

    running_length = 0
    last_color = None
    for token in unicode_token_values:
        color = background[running_length % len(background)]
        if color == last_color:
            color = background[(running_length + 1) % len(background)]
            assert color != last_color
        last_color = color
        running_length += len(token)
        print(color + token, end="")
    print("\u001b[0m")
def train_simple_encoding():
    gpt2_pattern = (
        r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    )

    import pandas as pd
    from pathlib import Path

    # the CSV must have a 'captions' (or similar) string column pointing at the audio caption files
    df = pd.read_csv('/path/to/csv/data/of/audio_captions.csv', sep='|')

    data = []
    for value in df['captions']:
        with Path(value).open('r') as f:
            data.append(f.read())
    data = " ".join(data)

    # vocab_size is the final number of mergeable ranks: the 256 byte tokens plus the learned merges
    enc = SimpleBytePairEncoding.train(data, vocab_size=512, pat_str=gpt2_pattern)

    # print("This is the sequence of merges performed in order to encode 'hello world':")
    # tokens = enc.encode("Привет! Попытка обучить нейронную сеть удалась не сразу, однако, у меня получилось")
    # assert enc.decode(tokens) == "Привет! Попытка обучить нейронную сеть удалась не сразу, однако, у меня получилось"
    # assert enc.decode_bytes(tokens) == "Привет! Попытка обучить нейронную сеть удалась не сразу, однако, у меня получилось".encode("utf-8")
    return enc


encoder = train_simple_encoding()
tokens = encoder.encode("Привет! Попытка обучить нейронную сеть удалась не сразу, однако, у меня получилось")

import torch

# path to the original metavoice-1b English tokenizer metadata from the first-stage checkpoint
_path = '/path/to/model/checkpoint/ckpt_000000.pth'
ckpt = torch.load(_path, mmap=True, weights_only=False)

trained_bpe_encoder = ckpt['meta'].copy()
trained_bpe_encoder['tokenizer']['pat_str'] = r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
trained_bpe_encoder['tokenizer']['mergeable_ranks'] = encoder.mergeable_ranks

# save that!
torch.save(trained_bpe_encoder, 'ru_bpe_tokenizer.pt')
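# Quick sanity check (a sketch): BPE decoding is lossless for valid UTF-8 input, so the trained
# encoder should reproduce the original Russian test sentence exactly.
assert encoder.decode(tokens) == "Привет! Попытка обучить нейронную сеть удалась не сразу, однако, у меня получилось"
print(f"Encoded the test sentence into {len(tokens)} tokens with a vocab of {len(encoder.mergeable_ranks)}")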
import tiktoken
import torch


class TrainedBPETokeniser:
    def __init__(self, name, pat_str, mergeable_ranks, special_tokens, offset=None) -> None:
        self.tokenizer = tiktoken.Encoding(
            name=name,
            pat_str=pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=special_tokens,
        )
        self.offset = offset

    def encode(self, text: str) -> list[int]:
        # note: we add an end-of-text token!
        tokens = self.tokenizer.encode(text) + [self.tokenizer.eot_token]
        if self.offset is not None:
            tokens = [x + self.offset for x in tokens]
        return tokens

    def decode(self, tokens: list[int]):
        if self.offset is not None:
            tokens = [x - self.offset for x in tokens]
        return self.tokenizer.decode(tokens)

    @property
    def eot_token(self):
        if self.offset is not None:
            return self.tokenizer.eot_token + self.offset
        else:
            return self.tokenizer.eot_token


# path to the original metavoice-1b English tokenizer metadata from the first-stage checkpoint
_path = '/path/to/model/checkpoint/ckpt_000000.pth'
ckpt = torch.load(_path, mmap=True, weights_only=False)

# for context, you can print the tokenizer metadata here
from rich import pprint
pprint(ckpt['meta'])

# create a new instance of TrainedBPETokeniser
tokenizer = TrainedBPETokeniser(**ckpt['meta']['tokenizer'])
# and this is the default metavoice tokenizer!
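# Sketch: loading the Russian tokenizer saved by the training script above. This assumes
# ru_bpe_tokenizer.pt holds the full copied 'meta' dict (with 'name' and 'special_tokens'
# carried over from the original checkpoint), which is what that script saves.
#
#   ru_meta = torch.load('ru_bpe_tokenizer.pt', weights_only=False)
#   ru_tokenizer = TrainedBPETokeniser(**ru_meta['tokenizer'])
#   ids = ru_tokenizer.encode("Привет!")
#   print(ru_tokenizer.decode(ids))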