To run BERT with TensorFlow 2.0, install the following packages:
pip install bert-for-tf2
pip install bert-tokenizer
pip install tensorflow-hub
pip install bert-tensorflow
pip install sentencepiece
import numpy as np
import tensorflow_hub as hub
import tensorflow as tf
import bert
from bert import tokenization
from tensorflow.keras.models import Model
max_seq_length = 128  # Your choice here.

# Three inputs expected by the BERT layer: token ids, attention mask, segment ids.
input_word_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                       name="input_word_ids")
input_mask = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                   name="input_mask")
segment_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32,
                                    name="segment_ids")

# Load the pre-trained BERT encoder from TensorFlow Hub and wrap it in a Keras model.
bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2", trainable=True)
pooled_output, sequence_output = bert_layer([input_word_ids, input_mask, segment_ids])
model = Model(inputs=[input_word_ids, input_mask, segment_ids], outputs=[pooled_output, sequence_output])
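# Optional sanity check (not part of the original gist): the hub layer yields a
# pooled [CLS] embedding of shape (batch, 768) and per-token embeddings of shape
# (batch, max_seq_length, 768); model.summary() lists the three inputs and both outputs.
model.summary()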
def get_masks(tokens, max_seq_length):
    """Mask for padding"""
    if len(tokens) > max_seq_length:
        raise IndexError("Token length more than max seq length!")
    return [1] * len(tokens) + [0] * (max_seq_length - len(tokens))

def get_segments(tokens, max_seq_length):
    """Segments: 0 for the first sequence, 1 for the second"""
    if len(tokens) > max_seq_length:
        raise IndexError("Token length more than max seq length!")
    segments = []
    current_segment_id = 0
    for token in tokens:
        segments.append(current_segment_id)
        if token == "[SEP]":
            current_segment_id = 1
    return segments + [0] * (max_seq_length - len(tokens))

def get_ids(tokens, tokenizer, max_seq_length):
    """Token ids from Tokenizer vocab"""
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = token_ids + [0] * (max_seq_length - len(token_ids))
    return input_ids
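# Illustration (assumed values, not from the original gist): for a short token
# list and max_seq_length = 8, the helpers pad everything to length 8:
#   tokens = ["[CLS]", "nice", "day", "[SEP]"]
#   get_masks(tokens, 8)          -> [1, 1, 1, 1, 0, 0, 0, 0]
#   get_segments(tokens, 8)       -> [0, 0, 0, 0, 0, 0, 0, 0]   # becomes 1 only after a first "[SEP]" when a second sentence follows
#   get_ids(tokens, tokenizer, 8) -> four vocab ids followed by four 0s (the "[PAD]" id)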
# Build the WordPiece tokenizer from the vocab file shipped with the hub layer.
BertTokenizer = tokenization.FullTokenizer
vocabulary_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
to_lower_case = bert_layer.resolved_object.do_lower_case.numpy()
tokenizer = BertTokenizer(vocabulary_file, to_lower_case)
#tokenizer = tokenization.bert_tokenization.FullTokenizer(vocabulary_file, to_lower_case)
s = "This is a nice sentence."
stokens = tokenizer.tokenize(s)
stokens = ["[CLS]"] + stokens + ["[SEP]"]
input_ids = get_ids(stokens, tokenizer, max_seq_length)
input_masks = get_masks(stokens, max_seq_length)
input_segments = get_segments(stokens, max_seq_length)
#pool_embs, all_embs = model.predict([[input_ids],[input_masks],[input_segments]])
pool_embs, all_embs = model.predict([np.array([input_ids]), np.array([input_masks]), np.array([input_segments])])
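As a follow-up, here is a minimal sketch (not part of the original gist) that wraps the preprocessing above into a hypothetical encode() helper and compares two sentences by the cosine similarity of their pooled embeddings; the helper name and the second example sentence are illustrative assumptions.
# Hypothetical helper: reuse the preprocessing above for any sentence.
def encode(sentence):
    tokens = ["[CLS]"] + tokenizer.tokenize(sentence) + ["[SEP]"]
    return [np.array([get_ids(tokens, tokenizer, max_seq_length)]),
            np.array([get_masks(tokens, max_seq_length)]),
            np.array([get_segments(tokens, max_seq_length)])]

print(pool_embs.shape)  # (1, 768): one pooled embedding for the sentence
print(all_embs.shape)   # (1, 128, 768): one vector per token position

# Compare two sentences via cosine similarity of their pooled embeddings.
pooled_a, _ = model.predict(encode("This is a nice sentence."))
pooled_b, _ = model.predict(encode("This sentence is quite pleasant."))
cosine = np.dot(pooled_a[0], pooled_b[0]) / (np.linalg.norm(pooled_a[0]) * np.linalg.norm(pooled_b[0]))
print(cosine)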