import tensorflow as tf
import os
import json
import regex as re
from functools import lru_cache
import requests
import boto3


@lru_cache()
def bytes_to_unicode():
    # Map every byte value (0-255) to a printable unicode character so that
    # BPE merges can operate on reversible, whitespace-free strings.
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    # Return the set of adjacent symbol pairs in a word (a tuple of symbols).
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class Encoder:
    def __init__(self, encoder, bpe_merges, errors="replace"):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors when decoding bytes back to text
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
        # Splits text into contractions, letter runs, number runs, punctuation runs, and whitespace.
        self.pat = re.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        )

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)
        if not pairs:
            return token

        while True:
            # Merge the highest-ranked (lowest index) pair first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def decode(self, tokens):
        text = "".join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text


def get_encoder():
    s3 = boto3.client("s3")
    encoder = json.load(
        s3.get_object(Bucket="cortex-test-project", Key="FeynmanBot/encoder.json")["Body"]
    )
    bpe_data = (
        s3.get_object(Bucket="cortex-test-project", Key="FeynmanBot/vocab.bpe")["Body"]
        .read()
        .decode("utf-8")
    )
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
    return Encoder(encoder=encoder, bpe_merges=bpe_merges)
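
A minimal usage sketch of the encoder above. It assumes AWS credentials with read access to the cortex-test-project bucket are available in the environment; the sample sentence is illustrative only.

# Usage sketch: round-trip a sentence through the byte-pair encoder.
# Assumes boto3 can reach s3://cortex-test-project/FeynmanBot/ as in get_encoder().
enc = get_encoder()
ids = enc.encode("Physics is fun.")  # list of integer token ids
print(ids)
print(enc.decode(ids))               # reconstructs the original text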