Skip to content

Instantly share code, notes, and snippets.

@a7v8x
Created May 5, 2020 19:01
Show Gist options
  • Save a7v8x/32544f0a452e092b2c54705c810b8eb0 to your computer and use it in GitHub Desktop.
Save a7v8x/32544f0a452e092b2c54705c810b8eb0 to your computer and use it in GitHub Desktop.
# Convert one dataset element into the (features, label) pair expected as
# input by TFBertForSequenceClassification.
def map_example_to_dict(input_ids, attention_masks, token_type_ids, label):
    """Pack per-example BERT inputs into a feature dict paired with its label.

    Intended as the callback for ``tf.data.Dataset.map``, which passes the
    components of each element positionally in the order
    ``(input_ids, attention_masks, token_type_ids, label)``.

    Returns:
        A 2-tuple ``(features, label)``. ``features`` maps ``"input_ids"``,
        ``"token_type_ids"`` and ``"attention_mask"`` to the corresponding
        arguments. ``label`` is returned unchanged.
    """
    features = {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        # The model-side key is singular even though the argument is plural.
        "attention_mask": attention_masks,
    }
    return features, label
def encode_examples(ds, limit=-1):
    """Encode ``(review, label)`` pairs from ``ds`` into a BERT-ready dataset.

    Args:
        ds: A ``tf.data.Dataset`` whose elements unpack into
            ``(review, label)``. Each ``review`` is converted by
            ``tfds.as_numpy`` and then ``.decode()``-d, so it is presumably a
            bytes string (TODO: confirm against the caller).
        limit: When greater than 0, only the first ``limit`` elements of
            ``ds`` are encoded (via ``ds.take``). Any other value, including
            the default ``-1``, encodes the whole of ``ds``.

    Returns:
        A ``tf.data.Dataset`` whose elements are the ``(features, label)``
        pairs produced by ``map_example_to_dict``. Each label is wrapped in a
        one-element list before slicing.
    """
    if limit > 0:
        ds = ds.take(limit)

    # Every encoded example is gathered into Python lists first, and the final
    # dataset is then built from those slices in a single call.
    columns = {"input_ids": [], "token_type_ids": [], "attention_mask": []}
    labels = []
    for review, label in tfds.as_numpy(ds):
        # convert_example_to_feature is defined elsewhere. It presumably
        # tokenizes to a fixed length so the slices stack; verify this.
        encoded = convert_example_to_feature(review.decode())
        for key, column in columns.items():
            column.append(encoded[key])
        labels.append([label])

    # The component order must match map_example_to_dict's positional
    # parameters: (input_ids, attention_masks, token_type_ids, label).
    components = (
        columns["input_ids"],
        columns["attention_mask"],
        columns["token_type_ids"],
        labels,
    )
    return tf.data.Dataset.from_tensor_slices(components).map(map_example_to_dict)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment