Created August 14, 2020 13:53
-
-
Save eerkaijun/a450cfae15e25fc4052aced6a7e95524 to your computer and use it in GitHub Desktop.
Convert a raw dataset into tensors compatible with a pretrained BERT model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def convert_data_to_examples(train, test, DATA_COLUMN, LABEL_COLUMN):
    """Wrap every row of the train and test tables in an ``InputExample``.

    Args:
        train: Training table (used row-wise via ``apply(..., axis=1)``,
            presumably a pandas DataFrame).
        test: Held-out table; its examples are returned as the validation set.
        DATA_COLUMN: Name of the column whose value becomes ``text_a``.
        LABEL_COLUMN: Name of the column whose value becomes ``label``.

    Returns:
        A ``(train_InputExamples, validation_InputExamples)`` pair, each being
        the result of a row-wise ``apply`` that builds one ``InputExample``
        per row.
    """

    def row_to_example(row):
        # Single-sentence input: no second segment (text_b) and no guid.
        return InputExample(
            guid=None,
            text_a=row[DATA_COLUMN],
            text_b=None,
            label=row[LABEL_COLUMN],
        )

    def frame_to_examples(frame):
        return frame.apply(row_to_example, axis=1)

    # Tuple elements evaluate left to right: train first, then test.
    return frame_to_examples(train), frame_to_examples(test)
def convert_examples_to_tf_dataset(examples, tokenizer, max_length=128):
    """Tokenize examples and expose them as a ``tf.data.Dataset``.

    Each example's ``text_a`` is encoded with special tokens, truncated and
    padded to ``max_length`` tokens. The resulting ids and masks are paired
    with the example's ``label``.

    Args:
        examples: Iterable of objects exposing ``text_a`` and ``label``
            (e.g. the ``InputExample`` objects built by
            ``convert_data_to_examples``).
        tokenizer: Tokenizer providing ``encode_plus`` — presumably the one
            matching the pretrained BERT model; confirm at the call site.
        max_length: Length every encoded sequence is truncated/padded to.

    Returns:
        An unbatched ``tf.data.Dataset`` of ``(features, label)`` pairs.
        ``features`` is a dict with int32 ``input_ids``, ``attention_mask``
        and ``token_type_ids`` vectors. ``label`` is declared as an int64
        scalar, so labels must be integer-valued.
    """
    features = []  # InputFeatures, built once and replayed by gen()
    for e in examples:
        input_dict = tokenizer.encode_plus(
            e.text_a,
            add_special_tokens=True,
            max_length=max_length,
            return_token_type_ids=True,
            return_attention_mask=True,
            # `pad_to_max_length=True` is deprecated in transformers;
            # padding="max_length" is the supported equivalent: pad every
            # sequence up to max_length.
            padding="max_length",
            truncation=True,  # cut sequences longer than max_length
        )
        features.append(
            InputFeatures(
                input_ids=input_dict["input_ids"],
                attention_mask=input_dict["attention_mask"],
                token_type_ids=input_dict["token_type_ids"],
                label=e.label,
            )
        )

    def gen():
        # Yield (features_dict, label) pairs matching the declared types below.
        for f in features:
            yield (
                {
                    "input_ids": f.input_ids,
                    "attention_mask": f.attention_mask,
                    "token_type_ids": f.token_type_ids,
                },
                f.label,
            )

    return tf.data.Dataset.from_generator(
        gen,
        (
            {
                "input_ids": tf.int32,
                "attention_mask": tf.int32,
                "token_type_ids": tf.int32,
            },
            tf.int64,
        ),
        (
            {
                "input_ids": tf.TensorShape([None]),
                "attention_mask": tf.TensorShape([None]),
                "token_type_ids": tf.TensorShape([None]),
            },
            tf.TensorShape([]),
        ),
    )
# Column names in the train/test tables holding the raw text and its label.
DATA_COLUMN = 'sentence'
LABEL_COLUMN = 'polarity'

# NOTE(review): `train`, `test` (row-wise tables) and `tokenizer` must already
# be defined before this point — they are not created here.
train_InputExamples, validation_InputExamples = convert_data_to_examples(
    train, test, DATA_COLUMN, LABEL_COLUMN
)

# Training pipeline: shuffle through a 100-element buffer, batch by 32,
# then repeat the batched data twice.
train_data = (
    convert_examples_to_tf_dataset(list(train_InputExamples), tokenizer)
    .shuffle(100)
    .batch(32)
    .repeat(2)
)

# Validation pipeline: batching only, no shuffling.
validation_data = convert_examples_to_tf_dataset(
    list(validation_InputExamples), tokenizer
).batch(32)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.