Batch encodes text data using a Hugging Face tokenizer
import tensorflow as tf

# Define the maximum number of tokens to keep per text (DistilBERT supports up to 512)
MAX_LENGTH = 128

# Define function to encode text data in batches
def batch_encode(tokenizer, texts, batch_size=256, max_length=MAX_LENGTH):
    """
    A function that encodes a batch of texts and returns the texts'
    corresponding encodings and attention masks that are ready to be fed
    into a pre-trained transformer model.

    Input:
        - tokenizer:   Tokenizer object from the PreTrainedTokenizer class
        - texts:       List of strings where each string represents a text
        - batch_size:  Integer controlling the number of texts per batch
        - max_length:  Integer controlling the maximum number of tokens kept per text

    Output:
        - input_ids:       sequence of texts encoded as a tf.Tensor object
        - attention_mask:  the texts' attention masks encoded as a tf.Tensor object
    """
    input_ids = []
    attention_mask = []

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = tokenizer.batch_encode_plus(batch,
                                             max_length=max_length,
                                             # Pad every batch to the same fixed length so all
                                             # batches can be stacked into one tensor below;
                                             # 'longest' would pad each batch independently and
                                             # could yield batches of different widths, making
                                             # tf.convert_to_tensor fail on the combined list.
                                             padding='max_length',
                                             truncation=True,
                                             return_attention_mask=True,
                                             return_token_type_ids=False
                                             )
        input_ids.extend(inputs['input_ids'])
        attention_mask.extend(inputs['attention_mask'])

    return tf.convert_to_tensor(input_ids), tf.convert_to_tensor(attention_mask)
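
# --- Hedged setup sketch: the calls below assume a `tokenizer` object and pandas
# Series `X_train`, `X_valid`, and `X_test` that the gist itself does not define.
# One plausible setup is sketched here; the checkpoint name and the toy Series are
# illustrative assumptions, not part of the original gist.
from transformers import DistilBertTokenizerFast
import pandas as pd

# Load a fast DistilBERT tokenizer (assumed checkpoint: 'distilbert-base-uncased')
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

# Toy stand-ins for the real train/validation/test splits (pandas Series,
# hence the .tolist() calls below)
X_train = pd.Series(["first training text", "second training text"])
X_valid = pd.Series(["a validation text"])
X_test  = pd.Series(["a test text"])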
# Encode X_train
X_train_ids, X_train_attention = batch_encode(tokenizer, X_train.tolist())
# Encode X_valid
X_valid_ids, X_valid_attention = batch_encode(tokenizer, X_valid.tolist())
# Encode X_test
X_test_ids, X_test_attention = batch_encode(tokenizer, X_test.tolist())
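
# --- Hedged follow-up: with padding='max_length', each returned tensor has shape
# (num_texts, MAX_LENGTH), ready to feed a TensorFlow transformer. The model class
# and checkpoint below are assumptions for illustration, not named in the gist.
from transformers import TFDistilBertModel

distilbert = TFDistilBertModel.from_pretrained('distilbert-base-uncased')
outputs = distilbert(input_ids=X_train_ids, attention_mask=X_train_attention)
print(outputs.last_hidden_state.shape)  # -> (num_texts, MAX_LENGTH, 768)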