Skip to content

Instantly share code, notes, and snippets.

@RamonYeung
Forked from ranihorev/BPE
Created June 27, 2019 08:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save RamonYeung/290142d5c421e0976e4d8db9dfc97cde to your computer and use it in GitHub Desktop.
Save RamonYeung/290142d5c421e0976e4d8db9dfc97cde to your computer and use it in GitHub Desktop.
Byte Pair Encoding (BPE) example (source: Sennrich et al., "Neural Machine Translation of Rare Words with Subword Units" — https://arxiv.org/abs/1508.07909)
import re, collections
def get_stats(vocab):
    """Count frequency-weighted occurrences of adjacent symbol pairs.

    Args:
        vocab: Mapping from a word written as whitespace-separated symbols
            to that word's frequency.

    Returns:
        A ``collections.defaultdict(int)`` keyed by ``(left, right)`` symbol
        tuples. Each value is the sum of the frequencies of every word in
        which that pair appears, counted once per occurrence. Pairs are
        inserted in the order they are first encountered.
    """
    counts = collections.defaultdict(int)
    for spelled, weight in vocab.items():
        tokens = spelled.split()
        # zip(tokens, tokens[1:]) walks each adjacent (left, right) pair;
        # words with fewer than two symbols contribute nothing.
        for left, right in zip(tokens, tokens[1:]):
            counts[left, right] += weight
    return counts
def merge_vocab(pair, v_in):
    """Merge every whole-symbol occurrence of ``pair`` in each word of ``v_in``.

    Args:
        pair: Two-symbol tuple such as ``('e', 's')``. Wherever the two
            symbols appear adjacently as whole symbols in a word, they are
            replaced by their concatenation.
        v_in: Mapping from a word written as space-separated symbols to its
            frequency. It is not modified.

    Returns:
        A new dict mapping each (possibly merged) word to its frequency. If
        several input words become identical after the merge, their
        frequencies are summed.
    """
    merged = ''.join(pair)
    bigram = re.escape(' '.join(pair))
    # The lookarounds require whitespace or a string edge on both sides, so
    # only whole symbols match. For example, 'e s' matches in 'e s t' but
    # not inside 'ee s'.
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    v_out = {}
    for word, freq in v_in.items():
        # A function replacement inserts `merged` literally. A string
        # replacement would parse backslashes in symbols as escapes or
        # group references.
        w_out = p.sub(lambda _match: merged, word)
        # Accumulate rather than overwrite: distinct input words can collapse
        # to the same merged word, e.g. 'l o w' and 'lo w' under ('l', 'o').
        if w_out in v_out:
            v_out[w_out] += freq
        else:
            v_out[w_out] = freq
    return v_out
# Toy corpus: each key is a word spelled as space-separated symbols ending in
# '</w>' (apparently the end-of-word marker from the referenced paper), and
# each value is that word's frequency.
vocab = {
    'l o w </w>': 5,
    'l o w e r </w>': 2,
    'n e w e s t </w>': 6,
    'w i d e s t </w>': 3,
}
num_merges = 6
for i in range(num_merges):
    pairs = get_stats(vocab)
    if not pairs:
        # Every word is a single symbol, so nothing is left to merge.
        # Without this guard, max() raises ValueError on the empty mapping.
        break
    # Greedily merge the most frequent pair. On ties, max() keeps the first
    # maximal key in iteration order.
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    print(best)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment