Skip to content

Instantly share code, notes, and snippets.

@joydeb28
Last active May 6, 2020 17:33
Show Gist options
  • Save joydeb28/0a5bfc7f45730a3a6f8b2dde5cb14656 to your computer and use it in GitHub Desktop.
Save joydeb28/0a5bfc7f45730a3a6f8b2dde5cb14656 to your computer and use it in GitHub Desktop.
class BertModel(object):
    """Wrap a TF-Hub BERT layer and its tokenizer to build model inputs.

    The class loads the BERT module from TF-Hub, builds a ``FullTokenizer``
    from the vocabulary shipped with that module, and offers helpers that
    turn raw sentences into the three padded integer sequences produced
    here: token ids, attention mask and segment ids.
    """

    def __init__(self):
        """Load the hub module and create the matching tokenizer."""
        # Fixed padded length used for every sequence built by this class.
        self.max_len = 128
        bert_path = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1"
        FullTokenizer = bert.bert_tokenization.FullTokenizer
        self.bert_module = hub.KerasLayer(bert_path, trainable=True)
        # The vocabulary path and casing flag are read from the loaded
        # module so the tokenizer is built from the module's own assets.
        self.vocab_file = self.bert_module.resolved_object.vocab_file.asset_path.numpy()
        self.do_lower_case = self.bert_module.resolved_object.do_lower_case.numpy()
        self.tokenizer = FullTokenizer(self.vocab_file, self.do_lower_case)

    def get_masks(self, tokens, max_seq_length):
        """Return the mask for *tokens*: 1 per token, then 0 padding.

        Args:
            tokens: Sequence of tokens (only its length is used).
            max_seq_length: Target length of the returned list.

        Returns:
            A list of ``len(tokens)`` ones followed by zeros up to
            ``max_seq_length``. If ``tokens`` is longer than
            ``max_seq_length`` no padding is added and the list keeps
            ``len(tokens)`` entries.
        """
        mask_data = [1] * len(tokens) + [0] * (max_seq_length - len(tokens))
        return mask_data

    def get_segments(self, tokens, max_seq_length):
        """Return segment ids for *tokens*, padded with 0.

        Segments: 0 for the first sequence, 1 for the second. The id
        switches to 1 for every token that follows the first ``"[SEP]"``;
        the ``"[SEP]"`` token itself keeps the id of the sequence it ends.

        Args:
            tokens: Sequence of token strings.
            max_seq_length: Target length of the returned list.

        Returns:
            A list of segment ids followed by zeros up to
            ``max_seq_length``.
        """
        segments = []
        segment_id = 0
        for token in tokens:
            # Append before checking so the separator stays in the
            # segment it terminates.
            segments.append(segment_id)
            if token == "[SEP]":
                segment_id = 1
        # Remaining positions are padded with 0.
        remaining_segment = [0] * (max_seq_length - len(tokens))
        segment_data = segments + remaining_segment
        return segment_data

    def get_ids(self, tokens, tokenizer, max_seq_length):
        """Convert *tokens* to vocabulary ids and pad with 0.

        Args:
            tokens: Sequence of token strings.
            tokenizer: Object exposing ``convert_tokens_to_ids(tokens)``.
            max_seq_length: Target length of the returned list.

        Returns:
            The ids returned by the tokenizer followed by zeros up to
            ``max_seq_length``.
        """
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        remaining_ids = [0] * (max_seq_length - len(token_ids))
        input_ids = token_ids + remaining_ids
        return input_ids

    def get_input_data(self, sentence, maxlen):
        """Build ``[ids, mask, segments]`` for a single sentence.

        Args:
            sentence: Raw text to tokenize.
            maxlen: Maximum number of word-piece tokens kept before the
                ``"[CLS]"`` and ``"[SEP]"`` markers are added.

        Returns:
            A list of three lists (ids, mask, segment ids), each padded
            to ``self.max_len``.
        """
        sent_token = self.tokenizer.tokenize(sentence)
        # Truncate first so the two special tokens added below still fit
        # when the caller passes ``self.max_len - 2``.
        sent_token = sent_token[:maxlen]
        sent_token = ["[CLS]"] + sent_token + ["[SEP]"]

        token_ids = self.get_ids(sent_token, self.tokenizer, self.max_len)
        mask = self.get_masks(sent_token, self.max_len)
        segment = self.get_segments(sent_token, self.max_len)
        input_data = [token_ids, mask, segment]
        return input_data

    def get_input_array(self, sentences):
        """Build the int32 input arrays for an iterable of sentences.

        Args:
            sentences: Iterable of raw text sentences.

        Returns:
            A list of three ``int32`` NumPy arrays, in the order ids,
            masks, segment ids, with one row per sentence.
        """
        input_ids, input_masks, input_segments = [], [], []

        for sentence in tqdm(sentences, position=0, leave=True):
            # Reserve two positions for the "[CLS]" and "[SEP]" markers.
            ids, masks, segments = self.get_input_data(sentence, self.max_len - 2)
            input_ids.append(ids)
            input_masks.append(masks)
            input_segments.append(segments)

        input_array = [
            np.asarray(input_ids, dtype=np.int32),
            np.asarray(input_masks, dtype=np.int32),
            np.asarray(input_segments, dtype=np.int32),
        ]
        return input_array
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment