Minimal example of getting BERT embeddings for a sentence, using TF 2.0 + TensorFlow Hub + the HuggingFace tokenizers library
import tensorflow as tf
import tensorflow_hub as hub
from tokenizers import BertWordPieceTokenizer
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
import numpy as np
class BERTPreprocessor:
    SEP_TOKEN = '[SEP]'

    def __init__(self, tokenizer, max_seq_length=512):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

    def pad2max(self, ar):
        # Right-pad a token-level list with zeros to max_seq_length and add a
        # leading batch axis, giving shape (1, max_seq_length).
        assert len(ar) <= self.max_seq_length
        return np.array(ar + [0] * (self.max_seq_length - len(ar)))[None]

    def get_mask(self, tokens):
        # Attention mask: 1 for real tokens, 0 for padding.
        return self.pad2max([1] * len(tokens))

    def get_segments(self, tokens):
        # Segment (token type) ids: 0 for the first sentence up to and
        # including its [SEP], 1 for the second sentence (if any).
        if self.SEP_TOKEN in tokens:
            sep_idx = tokens.index(self.SEP_TOKEN)
            return self.pad2max([0] * (sep_idx + 1) + [1] * (len(tokens) - sep_idx - 1))
        return self.pad2max([0] * len(tokens))

    def get_ids(self, tokens):
        # Map WordPiece tokens to vocabulary ids.
        return self.pad2max([self.tokenizer.token_to_id(token) for token in tokens])

    def get_data(self, sentence):
        enc = self.tokenizer.encode(sentence)
        return enc, [self.get_ids(enc.tokens), self.get_mask(enc.tokens), self.get_segments(enc.tokens)]
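# Note: tokenizer.encode() adds '[CLS]' and '[SEP]' by default, so enc.tokens
# already includes the special tokens; each of the three returned arrays has
# shape (1, 512) with the default max_seq_length.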
# Load BERT-base (cased) from TF Hub and wrap it in a Keras model that maps
# (ids, mask, segments) to (pooled_output, sequence_output).
bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/2")
input_word_ids = Input(shape=(512,), dtype=tf.int32)
input_mask = Input(shape=(512,), dtype=tf.int32)
segment_ids = Input(shape=(512,), dtype=tf.int32)
pooled_output, sequence_output = bert_layer([input_word_ids, input_mask, segment_ids])
model = Model(inputs=[input_word_ids, input_mask, segment_ids], outputs=[pooled_output, sequence_output])

# Build a WordPiece tokenizer from the vocab file shipped with the TF Hub module,
# matching its casing behavior.
vocab_filename = bert_layer.resolved_object.vocab_file.asset_path.numpy().decode('utf-8')
do_lower_case = bool(bert_layer.resolved_object.do_lower_case.numpy())
tokenizer = BertWordPieceTokenizer(vocab_filename, lowercase=do_lower_case)

tokenized_encoding, data = BERTPreprocessor(tokenizer).get_data("This is a sample sentence.")
pooled_embedding, token_embeddings = model.predict(data)
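# Expected shapes, given the 512-token inputs and BERT-base's 768-dim hidden size:
# pooled_embedding is the [CLS]-derived sentence vector; token_embeddings holds
# the per-token contextual vectors.
print(pooled_embedding.shape)   # (1, 768)
print(token_embeddings.shape)   # (1, 512, 768)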