import tensorflow as tf
from transformers import BertTokenizer, TFBertForSequenceClassification
import numpy as np

# seq_length = 128
# nb_examples = 1
# voc_size = 25000
# input_ids = tf.random.uniform((nb_examples, seq_length),
#                               maxval=voc_size,
#                               dtype=tf.dtypes.int32)
# attention_mask = tf.fill(tf.shape(input_ids), 1)
# token_type_ids = tf.zeros((nb_examples, seq_length))
# model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')
# inputs = [input_ids, attention_mask, token_type_ids]
# test_1 = model(inputs=inputs)
# inputs = {'input_ids': input_ids, 'attention_mask': attention_mask, 'token_type_ids': token_type_ids}
# test_2 = model(inputs=inputs)

saved_model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')
# saved_model = tf.keras.models.load_model('examples/serving/saved_model/bertseq/1'), or an already trained model, for instance from https://github.com/huggingface/transformers/blob/master/examples/run_tf_glue.py
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')  # tokenizer should match the uncased checkpoint loaded above
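# Illustrative aside (not part of the original gist): `encode_plus` returns a dict whose
# "input_ids" and "token_type_ids" lists are what `prepare_batch` below pads out to max_length.
# For example, assuming the tokenizer loaded above:
#     enc = tokenizer.encode_plus('Hello my name is Victor.', max_length=128, add_special_tokens=True)
#     enc["input_ids"]       # [CLS] + wordpiece ids + [SEP], shorter than 128 before padding
#     enc["token_type_ids"]  # all zeros for a single-sentence input
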
class FullModel(tf.keras.Model):
    def __init__(self,
                 add_special_tokens=True,
                 max_length=128,
                 pad_on_left=False,
                 pad_token=0,
                 pad_token_segment_id=0,
                 mask_padding_with_zero=True):
        super(FullModel, self).__init__()
        self.add_special_tokens = add_special_tokens
        self.max_length = max_length
        self.pad_on_left = pad_on_left
        self.pad_token = pad_token
        self.pad_token_segment_id = pad_token_segment_id
        self.mask_padding_with_zero = mask_padding_with_zero
        self.tokenizer = tokenizer
        self.bert = saved_model

    # @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
    def prepare_batch(self, texts):
        """
        Heavily inspired by https://github.com/huggingface/transformers/blob/master/transformers/data/processors/glue.py
        Related to https://github.com/tensorflow/tensorflow/issues/31055
        """
        def _tokenize(t):
            inputs = self.tokenizer.encode_plus(t, max_length=self.max_length, add_special_tokens=self.add_special_tokens)
            input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
            attention_mask = [1 if self.mask_padding_with_zero else 0] * len(input_ids)

            padding_length = self.max_length - len(input_ids)
            if self.pad_on_left:
                input_ids = ([self.pad_token] * padding_length) + input_ids
                attention_mask = ([0 if self.mask_padding_with_zero else 1] * padding_length) + attention_mask
                token_type_ids = ([self.pad_token_segment_id] * padding_length) + token_type_ids
            else:
                input_ids = input_ids + ([self.pad_token] * padding_length)
                attention_mask = attention_mask + ([0 if self.mask_padding_with_zero else 1] * padding_length)
                token_type_ids = token_type_ids + ([self.pad_token_segment_id] * padding_length)
            return input_ids, attention_mask, token_type_ids

        rslt = list(map(_tokenize, texts))
        inputs_dict = {
            "input_ids": tf.constant([i[0] for i in rslt]),
            "attention_mask": tf.constant([i[1] for i in rslt]),
            "token_type_ids": tf.constant([i[2] for i in rslt])
        }
        return inputs_dict

    def call(self, texts):
        inputs_dict = self.prepare_batch(texts)
        return self.bert(inputs=inputs_dict)


full_model = FullModel()
text1 = 'Hello my name is Victor.'
text2 = 'Goodbye, his name is John.'
test_3 = full_model([text1, text2])

# full_model.predict([text1], batch_size=1)
full_model._set_inputs(np.array([text1]))
full_model.save('examples/serving/saved_model/fullbert/0')
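# Illustrative follow-up (not in the original gist; assumes the save above succeeds):
# reload the exported model and run it directly on raw strings, mirroring the
# load_model hint used earlier for the plain BERT model.
#     reloaded = tf.keras.models.load_model('examples/serving/saved_model/fullbert/0')
#     test_4 = reloaded([text1, text2])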