Last active
September 3, 2023 14:02
-
-
Save RayWilliam46/c2cdc2e41bef33b332151d7acc2afef2 to your computer and use it in GitHub Desktop.
Batch encodes text data using a Hugging Face tokenizer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Maximum number of tokens per text (DistilBERT supports up to 512).
MAX_LENGTH = 128
# Encode text data in fixed-size batches to bound peak memory use.
def batch_encode(tokenizer, texts, batch_size=256, max_length=MAX_LENGTH):
    """Tokenize a list of texts into model-ready input IDs and attention masks.

    Texts are encoded in slices of ``batch_size`` so very large corpora do
    not have to be tokenized in a single call.

    Input:
        - tokenizer: Tokenizer object from the PreTrainedTokenizer class
        - texts: List of strings where each string represents a text
        - batch_size: Integer controlling the number of texts per batch
        - max_length: Integer controlling the max number of tokens kept
          per text (longer texts are truncated)

    Output:
        - input_ids: sequence of encoded texts as a tf.Tensor object
        - attention_mask: the texts' attention masks as a tf.Tensor object

    NOTE(review): relies on ``tf`` (TensorFlow) being imported at module
    level — confirm the import exists above this snippet.
    """
    input_ids = []
    attention_mask = []

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = tokenizer.batch_encode_plus(
            batch,
            max_length=max_length,
            padding='longest',  # dynamic padding: pad only to the longest text in this batch
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=False,
        )
        input_ids.extend(inputs['input_ids'])
        attention_mask.extend(inputs['attention_mask'])

    # WARNING(review): 'longest' padding is per-batch, so different batches
    # may produce different sequence lengths; tf.convert_to_tensor will fail
    # on ragged rows. Use padding='max_length' if batches can differ — confirm
    # against the original author's data.
    return tf.convert_to_tensor(input_ids), tf.convert_to_tensor(attention_mask)
# Encode the train / validation / test splits.
# NOTE(review): `tokenizer` and the X_* splits (presumably pandas Series,
# given .tolist()) are defined elsewhere in the notebook — verify upstream.
X_train_ids, X_train_attention = batch_encode(tokenizer, X_train.tolist())

X_valid_ids, X_valid_attention = batch_encode(tokenizer, X_valid.tolist())

X_test_ids, X_test_attention = batch_encode(tokenizer, X_test.tolist())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment