@eerkaijun
Created August 14, 2020 13:53
Convert a raw dataset into tensors compatible with a pretrained BERT model.
import tensorflow as tf
from transformers import InputExample, InputFeatures


def convert_data_to_examples(train, test, DATA_COLUMN, LABEL_COLUMN):
    """Wrap each row of the train/test DataFrames in an InputExample."""
    train_InputExamples = train.apply(
        lambda x: InputExample(guid=None,  # Globally unique ID for bookkeeping, unused in this case
                               text_a=x[DATA_COLUMN],
                               text_b=None,
                               label=x[LABEL_COLUMN]),
        axis=1,
    )

    validation_InputExamples = test.apply(
        lambda x: InputExample(guid=None,  # Globally unique ID for bookkeeping, unused in this case
                               text_a=x[DATA_COLUMN],
                               text_b=None,
                               label=x[LABEL_COLUMN]),
        axis=1,
    )

    return train_InputExamples, validation_InputExamples


def convert_examples_to_tf_dataset(examples, tokenizer, max_length=128):
    """Tokenize the InputExamples and wrap them in a tf.data.Dataset."""
    features = []  # will hold InputFeatures to be converted later

    for e in examples:
        # The tokenizer documentation covers these options in detail
        input_dict = tokenizer.encode_plus(
            e.text_a,
            add_special_tokens=True,
            max_length=max_length,      # truncates if len(s) > max_length
            return_token_type_ids=True,
            return_attention_mask=True,
            padding="max_length",       # pads to the right by default (replaces the deprecated pad_to_max_length)
            truncation=True,
        )

        input_ids, token_type_ids, attention_mask = (
            input_dict["input_ids"],
            input_dict["token_type_ids"],
            input_dict["attention_mask"],
        )

        features.append(
            InputFeatures(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                label=e.label,
            )
        )

    def gen():
        for f in features:
            yield (
                {
                    "input_ids": f.input_ids,
                    "attention_mask": f.attention_mask,
                    "token_type_ids": f.token_type_ids,
                },
                f.label,
            )

    return tf.data.Dataset.from_generator(
        gen,
        ({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64),
        (
            {
                "input_ids": tf.TensorShape([None]),
                "attention_mask": tf.TensorShape([None]),
                "token_type_ids": tf.TensorShape([None]),
            },
            tf.TensorShape([]),
        ),
    )


DATA_COLUMN = 'sentence'
LABEL_COLUMN = 'polarity'

# train and test are your DataFrames
train_InputExamples, validation_InputExamples = convert_data_to_examples(train, test, DATA_COLUMN, LABEL_COLUMN)

train_data = convert_examples_to_tf_dataset(list(train_InputExamples), tokenizer)
train_data = train_data.shuffle(100).batch(32).repeat(2)

validation_data = convert_examples_to_tf_dataset(list(validation_InputExamples), tokenizer)
validation_data = validation_data.batch(32)
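
For context, here is a minimal sketch of how these datasets are typically consumed. It assumes a TFBertForSequenceClassification model and the matching bert-base-uncased tokenizer; the checkpoint name, learning rate, and epoch count are illustrative assumptions, not part of the original gist.

import tensorflow as tf
from transformers import BertTokenizer, TFBertForSequenceClassification

# Assumed setup: any BERT checkpoint with a TF sequence-classification head would work here
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased")

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy("accuracy")],
)

# train_data / validation_data are the tf.data.Datasets built above
model.fit(train_data, epochs=2, validation_data=validation_data)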