@odashi
Last active August 15, 2017 12:23
Byte-pair encoding tools
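This gist bundles four small command-line scripts, in order: an encoder that rewrites a word corpus as space-separated integer IDs, two variants of the vocabulary trainer (the first updates bigram frequencies incrementally after each merge, the second recounts them from scratch), and a decoder that reverses the encoding. All of them take --input and --output, plus either --vocab (encoder/decoder) or --size/--min-freq/--threads (trainer).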
#!/usr/bin/env python3
import sys
from argparse import ArgumentParser
from collections import defaultdict


def parse_args():
    p = ArgumentParser(
        description='Converts words to integers using byte-pair encoding.')
    p.add_argument(
        '--input',
        type=str, metavar='FILE', required=True, help='source corpus')
    p.add_argument(
        '--output',
        type=str, metavar='FILE', required=True, help='output corpus')
    p.add_argument(
        '--vocab',
        type=str, metavar='FILE', required=True, help='BPE vocabulary file')
    args = p.parse_args()
    return args


def load_vocab(filename):
    vocab = {}
    # Characters missing from the vocabulary fall back to '<unk>'.
    chars = defaultdict(lambda: '<unk>')
    ops = []
    with open(filename) as fp:
        for line in fp:
            wid, left, right, *rest = line.strip().split('\t')
            if len(rest) == 0:
                # Three fields: a character entry "id<TAB>char<TAB>freq".
                vocab[left] = wid
                chars[left] = left
            else:
                # Four fields: a merge entry "id<TAB>left<TAB>right<TAB>freq".
                vocab[left + ' ' + right] = wid
                before = '\t' + left + '\t' + right + '\t'
                after = '\t' + left + ' ' + right + '\t'
                ops.append((before, after))
    return vocab, chars, ops


def convert(line, vocab, chars, ops, memo):
    wids = []
    for word in line.split():
        if word in memo:
            wids.append(memo[word])
        else:
            # Split the word into characters, then replay every merge
            # operation in the order it was learned.
            subwords = '\t' + '\t'.join(chars[c] for c in word) + '\t</w>\t'
            for before, after in ops:
                subwords = subwords.replace(before, after)
            result = ' '.join(
                vocab[x] for x in subwords.strip('\t').split('\t'))
            memo[word] = result
            wids.append(result)
    return ' '.join(wids)


def main():
    args = parse_args()
    vocab, chars, ops = load_vocab(args.vocab)
    memo = {}
    with open(args.input) as ifp, open(args.output, 'w') as ofp:
        for i, line in enumerate(ifp):
            print(convert(line, vocab, chars, ops, memo), file=ofp)
            if (i + 1) % 100 == 0:
                print('Processed %d lines.' % (i + 1), end='\r',
                      file=sys.stderr)
        print('Processed %d lines.' % (i + 1), file=sys.stderr)


if __name__ == '__main__':
    main()
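How convert() applies the learned merges may be easier to see on a toy word. The snippet below is a minimal standalone sketch, not part of the scripts; the two merge operations are made up for illustration:

# Toy illustration of the tab-delimited merge trick.
word = 'lower'
subwords = '\t' + '\t'.join(word) + '\t</w>\t'  # '\tl\to\tw\te\tr\t</w>\t'
ops = [
    ('\tl\to\t', '\tl o\t'),      # hypothetical merge: 'l' + 'o' -> 'l o'
    ('\tl o\tw\t', '\tl o w\t'),  # hypothetical merge: 'l o' + 'w' -> 'l o w'
]
for before, after in ops:
    subwords = subwords.replace(before, after)
print(subwords.strip('\t').split('\t'))  # ['l o w', 'e', 'r', '</w>']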
#!/usr/bin/env python3
import multiprocessing
import sys
import time
from argparse import ArgumentParser
from collections import defaultdict


def parse_args():
    p = ArgumentParser(description='Constructs vocabulary file.')
    p.add_argument(
        '--input',
        type=str, metavar='FILE', required=True, help='source corpus')
    p.add_argument(
        '--output',
        type=str, metavar='FILE', required=True, help='vocabulary file')
    p.add_argument(
        '--size',
        type=int, metavar='N', required=True, help='vocabulary size')
    p.add_argument(
        '--min-freq',
        type=int, metavar='N', required=True,
        help='minimum number of occurrences per character')
    p.add_argument(
        '--threads',
        type=int, metavar='N', required=True, help='number of threads')
    args = p.parse_args()
    # IDs 0..2 are reserved for <unk>, <s> and </s>.
    assert args.size > 3
    return args


def trace(*args, nolf=False):
    print(*args, file=sys.stderr, end='\r' if nolf else '\n')
    sys.stderr.flush()


def word2key(word):
    # A key is the word's subwords joined and surrounded by tabs,
    # e.g. 'low' -> '\tl\to\tw\t</w>\t'.
    return '\t' + '\t'.join(word) + '\t</w>\t'


def key2subwords(key):
    return key.strip('\t').split('\t')


def subwords2key(subwords):
    return '\t' + '\t'.join(subwords) + '\t'


def calculate_unigram_freq(encoding):
    freq = defaultdict(int)
    for key, val in encoding.items():
        for sw in key2subwords(key):
            freq[sw] += val
    return freq


def calculate_bigram_freq(encoding):
    freq = defaultdict(int)
    for key, val in encoding.items():
        subwords = key2subwords(key)
        for i in range(len(subwords) - 1):
            freq[subwords[i], subwords[i + 1]] += val
    return freq


def load_initial_encoding(filename):
    encoding = defaultdict(int)
    with open(filename) as fp:
        for i, line in enumerate(fp):
            for word in line.split():
                key = word2key(word)
                encoding[key] += 1
            if (i + 1) % 10000 == 0:
                trace('Loaded', i + 1, 'lines', nolf=True)
        trace('Loaded', i + 1, 'lines')
    trace('#unique encodings:', len(encoding))
    return i + 1, encoding


def filter_chars(encoding, min_freq):
    freq = calculate_unigram_freq(encoding)
    trace('#unique characters:', len(freq))
    # Use a set for fast membership tests below.
    filtered = {c for c in freq if freq[c] >= min_freq}
    trace('#filtered characters:', len(filtered))
    result = defaultdict(int)
    for key, val in encoding.items():
        subwords = key2subwords(key)
        new_subwords = [(sw if sw in filtered else '<unk>') for sw in subwords]
        new_key = subwords2key(new_subwords)
        result[new_key] += val
    trace('#filtered encodings:', len(result))
    return filtered, result


def make_shards(encoding, n):
    shards = [{} for _ in range(n)]
    for i, (key, val) in enumerate(encoding.items()):
        shards[i % n][key] = val
    return shards


def merge_freqs(freq, diffs):
    for diff in diffs:
        for key, val in diff.items():
            freq[key] += val


def merge_bigram(config):
    encoding, (left, right) = config
    before = '\t' + left + '\t' + right + '\t'
    after = '\t' + left + ' ' + right + '\t'
    new_encoding = {}
    diff = defaultdict(int)
    for key, val in encoding.items():
        new_key = key.replace(before, after)
        if new_key != key:
            # Record how the merge changes the bigram frequencies:
            # add the new key's bigrams and subtract the old key's.
            subwords = key2subwords(new_key)
            for i in range(len(subwords) - 1):
                diff[subwords[i], subwords[i + 1]] += val
            subwords = key2subwords(key)
            for i in range(len(subwords) - 1):
                diff[subwords[i], subwords[i + 1]] -= val
        new_encoding[new_key] = val
    return new_encoding, {key: val for key, val in diff.items() if val != 0}


def main():
    args = parse_args()
    total_begin_time = time.time()
    max_words = args.size - 3
    num_lines, encoding = load_initial_encoding(args.input)
    chars, encoding = filter_chars(encoding, args.min_freq)
    assert len(chars) <= max_words
    pool = multiprocessing.Pool(args.threads)
    shards = make_shards(encoding, args.threads)
    for i, shard in enumerate(shards):
        trace('Shard %d size: %d' % (i, len(shard)))
    freq = defaultdict(int)
    merge_freqs(freq, pool.imap_unordered(calculate_bigram_freq, shards))
    ops = []
    for i in range(len(chars), max_words):
        begin_time = time.time()
        left, right = max(freq, key=freq.get)
        merged_freq = freq[left, right]
        results = pool.map(merge_bigram, ((x, (left, right)) for x in shards))
        shards = [x[0] for x in results]
        # Apply only the returned diffs instead of recounting all bigrams.
        merge_freqs(freq, (x[1] for x in results))
        ops.append((left, right))
        elapsed = time.time() - begin_time
        trace(
            'Merged %d/%d: "%s" + "%s" (freq=%d, time=%fs)'
            % (i + 1, max_words, left, right, merged_freq, elapsed))
    trace('Writing vocabulary file ...')
    freq = defaultdict(int)
    merge_freqs(freq, pool.imap_unordered(calculate_unigram_freq, shards))
    num_unk = freq['<unk>'] if '<unk>' in freq else 0
    with open(args.output, 'w') as fp:
        print('0\t<unk>\t%d' % num_unk, file=fp)
        print('1\t<s>\t%d' % num_lines, file=fp)
        print('2\t</s>\t%d' % num_lines, file=fp)
        print('3\t</w>\t%d' % freq['</w>'], file=fp)
        # '</w>' already has ID 3, so exclude it before numbering the
        # remaining characters; otherwise the last character ID would
        # collide with the first merge ID.
        for i, c in enumerate(sorted(c for c in chars if c != '</w>')):
            print('%d\t%s\t%d' % (i + 4, c, freq[c]), file=fp)
        for i, bigram in enumerate(ops):
            key = ' '.join(bigram)
            print(
                '%d\t%s\t%s\t%d'
                % (i + 3 + len(chars), bigram[0], bigram[1], freq[key]),
                file=fp)
    total_elapsed = time.time() - total_begin_time
    trace('Total time elapsed: %fs' % total_elapsed)


if __name__ == '__main__':
    main()
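For reference, the vocabulary file this script writes has tab-separated fields: three for the specials and characters, four for merge entries. The tokens, IDs, and counts below are invented for illustration:

0    <unk>   17
1    <s>     10000
2    </s>    10000
3    </w>    53410
4    a       8120
5    b       1290
...
64   l       o       5230    (merge entry: new subword "l o")
65   l o     w       4870    (later merges can reference earlier ones)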
#!/usr/bin/env python3
import multiprocessing
import sys
import time
from argparse import ArgumentParser
from collections import defaultdict


def parse_args():
    p = ArgumentParser(description='Constructs vocabulary file.')
    p.add_argument(
        '--input',
        type=str, metavar='FILE', required=True, help='source corpus')
    p.add_argument(
        '--output',
        type=str, metavar='FILE', required=True, help='vocabulary file')
    p.add_argument(
        '--size',
        type=int, metavar='N', required=True, help='vocabulary size')
    p.add_argument(
        '--min-freq',
        type=int, metavar='N', required=True,
        help='minimum number of occurrences per character')
    p.add_argument(
        '--threads',
        type=int, metavar='N', required=True, help='number of threads')
    args = p.parse_args()
    # IDs 0..2 are reserved for <unk>, <s> and </s>.
    assert args.size > 3
    return args


def trace(*args, nolf=False):
    print(*args, file=sys.stderr, end='\r' if nolf else '\n')
    sys.stderr.flush()


def word2key(word):
    # A key is the word's subwords joined and surrounded by tabs,
    # e.g. 'low' -> '\tl\to\tw\t</w>\t'.
    return '\t' + '\t'.join(word) + '\t</w>\t'


def key2subwords(key):
    return key.strip('\t').split('\t')


def subwords2key(subwords):
    return '\t' + '\t'.join(subwords) + '\t'


def calculate_unigram_freq(encoding):
    freq = defaultdict(int)
    for key, val in encoding.items():
        for sw in key2subwords(key):
            freq[sw] += val
    return freq


def calculate_bigram_freq(encoding):
    freq = defaultdict(int)
    for key, val in encoding.items():
        subwords = key2subwords(key)
        for i in range(len(subwords) - 1):
            freq[subwords[i], subwords[i + 1]] += val
    return freq


def load_initial_encoding(filename):
    encoding = defaultdict(int)
    with open(filename) as fp:
        for i, line in enumerate(fp):
            for word in line.split():
                key = word2key(word)
                encoding[key] += 1
            if (i + 1) % 10000 == 0:
                trace('Loaded', i + 1, 'lines', nolf=True)
        trace('Loaded', i + 1, 'lines')
    trace('#unique encodings:', len(encoding))
    return i + 1, encoding


def filter_chars(encoding, min_freq):
    freq = calculate_unigram_freq(encoding)
    trace('#unique characters:', len(freq))
    # Use a set for fast membership tests below.
    filtered = {c for c in freq if freq[c] >= min_freq}
    trace('#filtered characters:', len(filtered))
    result = defaultdict(int)
    for key, val in encoding.items():
        subwords = key2subwords(key)
        new_subwords = [(sw if sw in filtered else '<unk>') for sw in subwords]
        new_key = subwords2key(new_subwords)
        result[new_key] += val
    trace('#filtered encodings:', len(result))
    return filtered, result


def make_shards(encoding, n):
    shards = [{} for _ in range(n)]
    for i, (key, val) in enumerate(encoding.items()):
        shards[i % n][key] = val
    return shards


def merge_freqs(freqs):
    total = defaultdict(int)
    for freq in freqs:
        for key, val in freq.items():
            total[key] += val
    return total


def merge_bigram(encoding):
    # The bigram to merge is smuggled into the shard under the integer
    # key 0 (all real keys are strings), so pool.map() needs only one
    # argument per shard.
    bigram = encoding[0]
    before = '\t' + bigram[0] + '\t' + bigram[1] + '\t'
    after = '\t' + bigram[0] + ' ' + bigram[1] + '\t'
    result = {}
    for key, val in encoding.items():
        if key == 0:
            continue
        new_key = key.replace(before, after)
        result[new_key] = val
    return result


def main():
    args = parse_args()
    total_begin_time = time.time()
    max_words = args.size - 3
    num_lines, encoding = load_initial_encoding(args.input)
    chars, encoding = filter_chars(encoding, args.min_freq)
    assert len(chars) <= max_words
    pool = multiprocessing.Pool(args.threads)
    shards = make_shards(encoding, args.threads)
    for i, shard in enumerate(shards):
        trace('Shard %d size: %d' % (i, len(shard)))
    ops = []
    for i in range(len(chars), max_words):
        begin_time = time.time()
        # Recount all bigram frequencies from scratch on every iteration.
        freq = merge_freqs(pool.imap_unordered(calculate_bigram_freq, shards))
        bigram = max(freq, key=freq.get)
        for shard in shards:
            shard[0] = bigram
        shards = pool.map(merge_bigram, shards)
        ops.append(bigram)
        elapsed = time.time() - begin_time
        l_str = ''.join(bigram[0].split(' '))
        r_str = ''.join(bigram[1].split(' '))
        trace(
            'Merged %d/%d: %s + %s (freq=%d, time=%fs)'
            % (i + 1, max_words, l_str, r_str, freq[bigram], elapsed))
    trace('Writing vocabulary file ...')
    freq = merge_freqs(pool.imap_unordered(calculate_unigram_freq, shards))
    num_unk = freq['<unk>'] if '<unk>' in freq else 0
    with open(args.output, 'w') as fp:
        print('0\t<unk>\t%d' % num_unk, file=fp)
        print('1\t<s>\t%d' % num_lines, file=fp)
        print('2\t</s>\t%d' % num_lines, file=fp)
        print('3\t</w>\t%d' % freq['</w>'], file=fp)
        # '</w>' already has ID 3, so exclude it before numbering the
        # remaining characters; otherwise the last character ID would
        # collide with the first merge ID.
        for i, c in enumerate(sorted(c for c in chars if c != '</w>')):
            print('%d\t%s\t%d' % (i + 4, c, freq[c]), file=fp)
        for i, bigram in enumerate(ops):
            key = ' '.join(bigram)
            print(
                '%d\t%s\t%s\t%d'
                % (i + 3 + len(chars), bigram[0], bigram[1], freq[key]),
                file=fp)
    total_elapsed = time.time() - total_begin_time
    trace('Total time elapsed: %fs' % total_elapsed)


if __name__ == '__main__':
    main()
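This variant trades speed for simplicity: it recounts every bigram on each iteration instead of applying the frequency diffs returned by the version above. One merge step behaves like the minimal sketch below (a hypothetical two-word corpus, run against the helpers defined in this script):

# One merge step on a toy corpus: 'low' x5, 'lower' x2.
encoding = {word2key('low'): 5, word2key('lower'): 2}
freq = calculate_bigram_freq(encoding)
bigram = max(freq, key=freq.get)  # ('l', 'o'): wins the tie with
                                  # ('o', 'w') by insertion order
encoding[0] = bigram              # smuggle the bigram in under key 0
print(merge_bigram(encoding))
# {'\tl o\tw\t</w>\t': 5, '\tl o\tw\te\tr\t</w>\t': 2}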
#!/usr/bin/env python3
from argparse import ArgumentParser
from collections import defaultdict


def parse_args():
    p = ArgumentParser(
        description='Converts integers to words using byte-pair encoding.')
    p.add_argument(
        '--input',
        type=str, metavar='FILE', required=True, help='source corpus')
    p.add_argument(
        '--output',
        type=str, metavar='FILE', required=True, help='output corpus')
    p.add_argument(
        '--vocab',
        type=str, metavar='FILE', required=True, help='BPE vocabulary file')
    args = p.parse_args()
    return args


def load_vocab(filename):
    # Map each word ID to its list of character tokens; unknown IDs
    # decode to ['<unk>'].
    vocab = defaultdict(lambda: ['<unk>'])
    with open(filename) as fp:
        for line in fp:
            wid, left, right, *rest = line.strip().split('\t')
            if len(rest) == 0:
                # Character entry: "left" is a single token.
                vocab[int(wid)] = left.split(' ')
            else:
                # Merge entry: concatenate both sides' tokens.
                vocab[int(wid)] = left.split(' ') + right.split(' ')
    return vocab


def convert(line, vocab):
    cache = []
    words = []
    for wid in line.split():
        cache += vocab[int(wid)]
        if cache[-1] == '</w>':
            # '</w>' closes the current word.
            words.append(''.join(cache[:-1]))
            cache = []
    if cache:
        words.append(''.join(cache))
    return ' '.join(words)


def main():
    args = parse_args()
    vocab = load_vocab(args.vocab)
    with open(args.input) as ifp, open(args.output, 'w') as ofp:
        for line in ifp:
            print(convert(line, vocab), file=ofp)


if __name__ == '__main__':
    main()
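A quick sanity check of the decoder's convert() with a hand-built vocabulary; the IDs and tokens are invented for illustration:

# '</w>' (ID 3 by convention) closes each word during decoding.
vocab = {3: ['</w>'], 4: ['l', 'o', 'w'], 5: ['e'], 6: ['r']}
print(convert('4 3 4 5 6 3', vocab))  # -> 'low lower'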