@8enmann
Created June 24, 2019 03:29
Approximate BPE implementation.
"""Implements an approximate BPE encoding over bytes with some tricks for efficiency.
https://arxiv.org/pdf/1508.07909.pdf section 3.2.
Basic algorithm from the paper:
Initialize the vocab with the character vocabulary
Each word is a sequence of characters plus an enod of word symbol '·'
Count all symbol pairs
Replace each occurence of the most frequent pair ('a', 'b') with 'ab'.
Each merge represents a character n-gram
Frequent n-grams are merged into a single symbol.
Repeat until max vocab size or computation budget is reached.
Unlike the paper, this implementation operates directly on utf-8 bytes,
so should work for any language or data type with no modification.
It provides the option to do multiple replacements per iteration for increased speed.
Encoding using a computed vocab is done greedily instead of by the standard algorithm.
TODO: benchmark against original.
"""
import multiprocessing as mp
from collections import Counter, deque
from typing import Dict, Iterable, List, Set, Tuple
import tqdm


def get_pairs(seq: Iterable) -> Iterable[Tuple]:
    """Yield a sliding window of length 2 from seq."""
    d = deque(maxlen=2)
    # Consume the first two items; a sequence shorter than 2 yields nothing.
    it = iter(seq)
    for _ in range(2):
        try:
            d.append(next(it))
        except StopIteration:
            return
    yield tuple(d)
    for i in it:
        d.append(i)
        yield tuple(d)
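# For example: list(get_pairs('abc')) == [('a', 'b'), ('b', 'c')], and
# list(get_pairs([b'a', b'a', b'b'])) == [(b'a', b'a'), (b'a', b'b')].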


class Worker(mp.Process):
    """Computes counts on a subset of the corpus.

    Waits for the master to tell it what to merge based on its siblings.
    Queues are child -> parent only.
    `top_k` is read only.
    """

    def __init__(
            self,
            top_k_ready: mp.Condition,
            top_k: 'DictProxy',
            count_q: mp.Queue,
            vocab_q: mp.Queue,
            corpus: str):
        super(Worker, self).__init__()
        self.top_k_ready = top_k_ready
        self.top_k = top_k
        self.vocab_q = vocab_q
        self.count_q = count_q
        self.corpus = corpus
        self.byte_list: Iterable[bytes] = None

    def run(self):
        """This shouldn't be called directly; call `worker.start()`."""
        print('started', self.name)
        self.byte_list = str_to_byte_list(self.corpus)
        self.vocab = set(self.byte_list)
        self.vocab_q.put(self.vocab)
        while True:
            counts = Counter(get_pairs(self.byte_list))
            self.count_q.put(counts)
            # Wait for the main process to send the top k merges.
            with self.top_k_ready:
                self.top_k_ready.wait()
            # An empty merge dict is the signal to stop.
            if len(self.top_k) == 0:
                break
            self.byte_list = list(merge(self.top_k, self.byte_list))
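# The handshake between the master and each Worker, per merge round
# (driven by compute_vocab_multi below):
#   0. On startup the worker puts its shard's initial byte vocab on vocab_q.
#   1. Each round it counts pairs in its shard and puts the Counter on count_q.
#   2. The master sums the Counters, writes the winning pairs into the shared
#      `top_k` dict, and notifies top_k_ready.
#   3. The worker wakes, applies the merges locally, and loops; an empty
#      `top_k` tells it to exit instead.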


def compute_vocab_multi(
        corpus: str,
        max_vocab_size: int = 3000,
        max_merges: int = 10,
        top_k: int = 1,
        n: int = mp.cpu_count()) -> Set[bytes]:
    """Multiprocess implementation of approximate BPE.

    Divides the corpus among n workers.

    Args:
        corpus: The corpus to encode. Could scale better by taking a list of filenames.
        max_vocab_size: Stop after generating this many vocab entries.
        max_merges: Stop after this many rounds.
        top_k: Each round merge the top k pairs. Standard BPE sets top_k=1.
        n: Number of worker processes.
    Returns:
        A set of all the vocab entries generated, each of which is a `bytes`.
    """
    top_k_ready = mp.Condition()
    vocab_q = mp.Queue()
    count_q = mp.Queue()
    chunk_size = len(corpus) // n
    vocab = set()
    with mp.Manager() as manager:
        to_merge = manager.dict()
        procs = []
        print('starting workers')
        for i in range(n):
            procs.append(Worker(
                top_k_ready,
                to_merge,
                count_q,
                vocab_q,
                # These overlap by one character on purpose so no pair
                # spanning a chunk boundary is lost.
                corpus[i * chunk_size:(i + 1) * chunk_size + 1]))
            procs[-1].start()
        # Get initial vocab from each worker.
        print('waiting for vocab from workers')
        for _ in range(n):
            vocab.update(vocab_q.get())
        print('got vocab', vocab)
        for i in range(max_merges):
            # Get counts from each worker; start fresh each round since the
            # workers re-count their whole shard after merging.
            counts = Counter()
            for _ in range(n):
                counts.update(count_q.get())
            print(counts)
            to_merge.clear()
            to_merge.update({x[0]: b''.join(x[0]) for x in counts.most_common(top_k)})
            vocab.update(to_merge.values())
            with top_k_ready:
                top_k_ready.notify_all()
            if len(vocab) >= max_vocab_size:
                break
        # Tell workers to stop: an empty merge dict is the stop signal.
        to_merge.clear()
        with top_k_ready:
            top_k_ready.notify_all()
        # Join with a timeout; a worker that misses the final notify would
        # otherwise block forever.
        for p in procs:
            p.join(1)
    return vocab
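# A minimal usage sketch (the file path is hypothetical; any string works):
#   corpus = open('corpus.txt').read()
#   vocab = compute_vocab_multi(corpus, max_vocab_size=3000, max_merges=100, top_k=10)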


def merge(to_merge: Dict[Tuple[bytes], bytes], seq: Iterable) -> Iterable:
    """Given a set of requested merges, go through the sequence and do the merges."""
    # Keys may come from a manager dict; rebuild locally for fast lookups.
    to_merge = {x: b''.join(x) for x in to_merge.keys()}
    just_merged = False
    pair = None
    for pair in get_pairs(seq):
        if just_merged:
            # pair[0] was consumed by the previous merge; pair[1] reappears
            # as the next pair's first element (or in the tail yield below).
            just_merged = False
            continue
        if pair in to_merge:
            just_merged = True
            yield to_merge[pair]
        else:
            yield pair[0]
    # Emit the final element unless it was swallowed by the last merge.
    if pair is not None and not just_merged:
        yield pair[1]
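# For example, with to_merge = {(b'a', b'b'): b'ab'}:
#   list(merge(to_merge, [b'a', b'b', b'c'])) == [b'ab', b'c']
#   list(merge(to_merge, [b'c', b'a', b'b'])) == [b'c', b'ab']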


def str_to_byte_list(s: str) -> Iterable[bytes]:
    """Encode `s` as UTF-8 and split it into single-byte `bytes` objects."""
    return [bytes([x]) for x in s.encode('utf8')]


def compute_vocab(corpus: str, max_vocab_size: int = 3000, max_merges: int = 10, top_k: int = 1) -> Set[bytes]:
    """Single threaded implementation of approximate BPE.

    Args:
        corpus: The corpus to encode. Could scale better by taking a list of filenames.
        max_vocab_size: Stop after generating this many vocab entries.
        max_merges: Stop after this many rounds.
        top_k: Each round merge the top k pairs. Standard BPE sets top_k=1.
    Returns:
        A set of all the vocab entries generated, each of which is a `bytes`.
    """
    if len(corpus) < min(max_merges, max_vocab_size):
        raise ValueError('Corpus must be bigger than min(max_merges, max_vocab_size)')
    l = str_to_byte_list(corpus)
    vocab = set(l)
    for i in tqdm.trange(max_merges):
        counts = Counter(get_pairs(l))
        # Merge the most common pairs.
        to_merge = {x[0]: b''.join(x[0]) for x in counts.most_common(top_k)}
        vocab.update(to_merge.values())
        l = list(merge(to_merge, l))
        if len(vocab) >= max_vocab_size:
            break
    return vocab
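# For example (mirrors test_compute_vocab_simple in the tests below):
#   vocab = compute_vocab('The quick brown fox jumped over the lazy dog. Wow! Amazing.')
#   # With the default 10 merge rounds this yields 41 entries: the 31 distinct
#   # bytes of the corpus plus one new merged entry per round here.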


class Encoder:
    DEFAULT_VOCAB_FILENAME = 'vocab.bpe'
    # Null bytes are unlikely to occur in naturally encoded text.
    DELIM = b'\0\n'
    # Special symbols must be at least 2 bytes long; otherwise the greedy
    # encoder probably won't have intermediate merges to pick up.
    EOF = b'\0F'
    UNK = b'\0UNK'

    def __init__(self, vocab: Iterable[bytes] = None, vocab_file: str = DEFAULT_VOCAB_FILENAME):
        if vocab:
            self.vocab = list(vocab)
        else:
            self.vocab = self.load(vocab_file)
        # Append special symbols.
        self.vocab += [self.EOF, self.UNK]
        # TODO: break keys into tuples for faster matching?
        self.encoder = {x: i for i, x in enumerate(self.vocab)}
        self.decoder = {i: x for i, x in enumerate(self.vocab)}
        self.max_length = max(map(len, self.vocab))
        self.UNK_EMB = len(self.vocab) - 1

    def encode(self, corpus: str) -> Iterable[int]:
        """Greedily encode `corpus` according to the vocab."""
        b = corpus.encode('utf8')
        start = 0
        while start < len(b):
            # Default to UNK, which consumes a single byte.
            match = self.UNK_EMB
            match_len = 1
            for end in range(start + 1, start + self.max_length + 1):
                substr = b[start:end]
                new_match = self.encoder.get(substr)
                if new_match is not None:
                    match = new_match
                    match_len = len(substr)
                    # Keep extending while there is input left and we haven't
                    # hit the longest vocab entry.
                    if end < len(b) and len(substr) < self.max_length:
                        continue
                # First failure (or end of input): emit the best match so far
                # and advance past it.
                yield match
                start += match_len
                break

    def decode(self, corpus: Iterable[int], errors='ignore') -> str:
        """Decode `corpus` according to the vocab."""
        return b''.join([self.decoder[x] for x in corpus]).decode('utf8', errors=errors)

    @classmethod
    def save(cls, vocab: Iterable[bytes], filename: str = DEFAULT_VOCAB_FILENAME):
        with open(filename, 'wb') as f:
            for v in vocab:
                f.write(v + cls.DELIM)

    @classmethod
    def load(cls, filename: str = DEFAULT_VOCAB_FILENAME):
        with open(filename, 'rb') as f:
            return f.read().split(cls.DELIM)[:-1]
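# A round-trip sketch (mirrors test_encode in the tests below; the toy vocab
# is arbitrary):
#   encoder = Encoder(str_to_byte_list('abcdef'))
#   ids = list(encoder.encode('aabb'))   # [0, 0, 1, 1]
#   encoder.decode(ids)                  # 'aabb'
#   list(encoder.encode('t'))            # [UNK id], since b't' is out of vocab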


def main():
    CORPUS_FILE = '/Users/ben/data/wikitext-2/wiki.train.tokens'
    with open(CORPUS_FILE) as f:
        corpus = f.read()
    print(len(corpus))
    vocab = compute_vocab(corpus[:10000], max_merges=100, top_k=10)
    print(len(vocab))
    # Save the mapping.
    Encoder.save(vocab)
    print(len(Encoder.load()))


if __name__ == '__main__':
    main()
"""Tests for bpe.py.
One of the tests uses a separate process.
The others mock out multiprocessing functionality, so should be lightweight.
"""
from unittest import mock
import pytest
import multiprocessing
import bpe


def test_get_pairs():
    s = 'aaabbb'
    pairs = list(bpe.get_pairs(s))
    assert len(pairs) == len(s) - 1
    assert list(bpe.get_pairs('abc')) == [('a', 'b'), ('b', 'c')]


def test_str_to_byte_list():
    assert [b'a', b'b'] == bpe.str_to_byte_list('ab')


def test_merge():
    l = bpe.str_to_byte_list('abc')
    to_merge = {x: b''.join(x) for x in [(b'a', b'b')]}
    merged = list(bpe.merge(to_merge, l))
    assert merged == [b'ab', b'c']


def test_compute_vocab_simple():
    TEST = 'The quick brown fox jumped over the lazy dog. Wow! Amazing.'
    vocab = bpe.compute_vocab(TEST)
    assert 41 == len(vocab)


def test_encode():
    encoder = bpe.Encoder(bpe.str_to_byte_list('abcdef'))
    test_str = 'aabb'
    encoded = list(encoder.encode(test_str))
    assert encoded == [0, 0, 1, 1]
    assert test_str == encoder.decode(encoded)
    # Test UNK handling.
    assert encoder.UNK.decode('utf8') == encoder.decode(encoder.encode('t'))


@mock.patch('multiprocessing.Queue')
@mock.patch('bpe.Worker')
@mock.patch('multiprocessing.Condition')
def test_compute_vocab_multi(Condition, MockWorker, Queue):
    q = Queue.return_value
    q.get.side_effect = [
        # Return the initial vocab.
        {b'a'},
        # Return the first set of counts.
        {(b'a', b'a'): 3, (b'b', b'b'): 2}]
    out = bpe.compute_vocab_multi('aaabb', n=1, max_vocab_size=2)
    assert 'aaabb' in MockWorker.call_args[0]
    assert out == {b'aa', b'a'}


@mock.patch('multiprocessing.Queue')
@mock.patch('bpe.Worker')
@mock.patch('multiprocessing.Condition')
def test_compute_vocab_multi_corpus_partition(Condition, MockWorker, Queue):
    # Get the mock Queue instance.
    q = Queue.return_value
    # Return the same thing every time.
    q.get.return_value = []
    out = bpe.compute_vocab_multi('aaabb', n=2, max_vocab_size=0)
    assert 'aaa' in MockWorker.call_args_list[0][0]
    assert 'abb' in MockWorker.call_args_list[1][0]
    # The queue returned nothing every time.
    assert out == set()


def test_worker():
    top_k_ready = multiprocessing.Condition()
    with multiprocessing.Manager() as m:
        top_k = m.dict()
        count_q = multiprocessing.Queue()
        vocab_q = multiprocessing.Queue()
        worker = bpe.Worker(top_k_ready, top_k, count_q, vocab_q, 'aaabb')
        worker.start()
        assert vocab_q.get() == {b'a', b'b'}
        counts = count_q.get()
        assert counts == {(b'a', b'a'): 2, (b'a', b'b'): 1, (b'b', b'b'): 1}
        top_k.update({x[0]: b''.join(x[0]) for x in counts.most_common(2)})
        with top_k_ready:
            top_k_ready.notify()
        # Round 2.
        counts = count_q.get()
        assert counts == {(b'aa', b'ab'): 1, (b'ab', b'b'): 1}
        # Finish.
        top_k.clear()
        with top_k_ready:
            top_k_ready.notify()
        worker.join()
        assert not worker.is_alive()


if __name__ == '__main__':
    pytest.main()